aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp
diff options
context:
space:
mode:
authorLiam Mitchell <[email protected]>2026-03-09 19:06:36 -0700
committerLiam Mitchell <[email protected]>2026-03-09 19:06:36 -0700
commitd1abc50ee9d4fb72efc646e17decafea741caa34 (patch)
treee4288e00f2f7ca0391b83d986efcb69d3ba66a83 /src/zenhttp
parentAllow requests with invalid content-types unless specified in command line or... (diff)
parentupdated chunk–block analyser (#818) (diff)
downloadzen-d1abc50ee9d4fb72efc646e17decafea741caa34.tar.xz
zen-d1abc50ee9d4fb72efc646e17decafea741caa34.zip
Merge branch 'main' into lm/restrict-content-type
Diffstat (limited to 'src/zenhttp')
-rw-r--r--src/zenhttp/auth/oidc.cpp24
-rw-r--r--src/zenhttp/clients/httpclientcommon.cpp323
-rw-r--r--src/zenhttp/clients/httpclientcommon.h115
-rw-r--r--src/zenhttp/clients/httpclientcpr.cpp579
-rw-r--r--src/zenhttp/clients/httpclientcpr.h15
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp566
-rw-r--r--src/zenhttp/httpclient.cpp451
-rw-r--r--src/zenhttp/httpclient_test.cpp1366
-rw-r--r--src/zenhttp/httpclientauth.cpp2
-rw-r--r--src/zenhttp/httpserver.cpp167
-rw-r--r--src/zenhttp/include/zenhttp/cprutils.h4
-rw-r--r--src/zenhttp/include/zenhttp/formatters.h2
-rw-r--r--src/zenhttp/include/zenhttp/httpapiservice.h1
-rw-r--r--src/zenhttp/include/zenhttp/httpclient.h53
-rw-r--r--src/zenhttp/include/zenhttp/httpcommon.h7
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h147
-rw-r--r--src/zenhttp/include/zenhttp/httpstats.h47
-rw-r--r--src/zenhttp/include/zenhttp/httpwsclient.h79
-rw-r--r--src/zenhttp/include/zenhttp/packageformat.h2
-rw-r--r--src/zenhttp/include/zenhttp/security/passwordsecurity.h38
-rw-r--r--src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h51
-rw-r--r--src/zenhttp/include/zenhttp/websocket.h65
-rw-r--r--src/zenhttp/monitoring/httpstats.cpp195
-rw-r--r--src/zenhttp/packageformat.cpp6
-rw-r--r--src/zenhttp/security/passwordsecurity.cpp176
-rw-r--r--src/zenhttp/security/passwordsecurityfilter.cpp56
-rw-r--r--src/zenhttp/servers/httpasio.cpp427
-rw-r--r--src/zenhttp/servers/httpasio.h2
-rw-r--r--src/zenhttp/servers/httpmulti.cpp31
-rw-r--r--src/zenhttp/servers/httpmulti.h12
-rw-r--r--src/zenhttp/servers/httpnull.cpp18
-rw-r--r--src/zenhttp/servers/httpnull.h1
-rw-r--r--src/zenhttp/servers/httpparser.cpp155
-rw-r--r--src/zenhttp/servers/httpparser.h12
-rw-r--r--src/zenhttp/servers/httpplugin.cpp140
-rw-r--r--src/zenhttp/servers/httpsys.cpp556
-rw-r--r--src/zenhttp/servers/httpsys_iocontext.h40
-rw-r--r--src/zenhttp/servers/httptracer.h4
-rw-r--r--src/zenhttp/servers/wsasio.cpp311
-rw-r--r--src/zenhttp/servers/wsasio.h77
-rw-r--r--src/zenhttp/servers/wsframecodec.cpp236
-rw-r--r--src/zenhttp/servers/wsframecodec.h74
-rw-r--r--src/zenhttp/servers/wshttpsys.cpp485
-rw-r--r--src/zenhttp/servers/wshttpsys.h107
-rw-r--r--src/zenhttp/servers/wstest.cpp925
-rw-r--r--src/zenhttp/transports/dlltransport.cpp38
-rw-r--r--src/zenhttp/transports/winsocktransport.cpp2
-rw-r--r--src/zenhttp/xmake.lua1
-rw-r--r--src/zenhttp/zenhttp.cpp4
49 files changed, 7522 insertions, 673 deletions
diff --git a/src/zenhttp/auth/oidc.cpp b/src/zenhttp/auth/oidc.cpp
index 38e7586ad..23bbc17e8 100644
--- a/src/zenhttp/auth/oidc.cpp
+++ b/src/zenhttp/auth/oidc.cpp
@@ -32,6 +32,25 @@ namespace details {
using namespace std::literals;
+static std::string
+FormUrlEncode(std::string_view Input)
+{
+ std::string Result;
+ Result.reserve(Input.size());
+ for (char C : Input)
+ {
+ if ((C >= 'A' && C <= 'Z') || (C >= 'a' && C <= 'z') || (C >= '0' && C <= '9') || C == '-' || C == '_' || C == '.' || C == '~')
+ {
+ Result.push_back(C);
+ }
+ else
+ {
+ Result.append(fmt::format("%{:02X}", static_cast<uint8_t>(C)));
+ }
+ }
+ return Result;
+}
+
OidcClient::OidcClient(const OidcClient::Options& Options)
{
m_BaseUrl = std::string(Options.BaseUrl);
@@ -67,6 +86,8 @@ OidcClient::Initialize()
.TokenEndpoint = Json["token_endpoint"].string_value(),
.UserInfoEndpoint = Json["userinfo_endpoint"].string_value(),
.RegistrationEndpoint = Json["registration_endpoint"].string_value(),
+ .EndSessionEndpoint = Json["end_session_endpoint"].string_value(),
+ .DeviceAuthorizationEndpoint = Json["device_authorization_endpoint"].string_value(),
.JwksUri = Json["jwks_uri"].string_value(),
.SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]),
.SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]),
@@ -81,7 +102,8 @@ OidcClient::Initialize()
OidcClient::RefreshTokenResult
OidcClient::RefreshToken(std::string_view RefreshToken)
{
- const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId);
+ const std::string Body =
+ fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", FormUrlEncode(RefreshToken), FormUrlEncode(m_ClientId));
HttpClient Http{m_Config.TokenEndpoint};
diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp
index 47425e014..6f4c67dd0 100644
--- a/src/zenhttp/clients/httpclientcommon.cpp
+++ b/src/zenhttp/clients/httpclientcommon.cpp
@@ -142,7 +142,10 @@ namespace detail {
DataSize -= CopySize;
if (m_CacheBufferOffset == CacheBufferSize)
{
- AppendData(m_CacheBuffer, CacheBufferSize);
+ if (std::error_code Ec = AppendData(m_CacheBuffer, CacheBufferSize))
+ {
+ return Ec;
+ }
if (DataSize > 0)
{
ZEN_ASSERT(DataSize < CacheBufferSize);
@@ -382,6 +385,177 @@ namespace detail {
return Result;
}
+ MultipartBoundaryParser::MultipartBoundaryParser() : BoundaryEndMatcher("--"), HeaderEndMatcher("\r\n\r\n") {}
+
+ bool MultipartBoundaryParser::Init(const std::string_view ContentTypeHeaderValue)
+ {
+ std::string LowerCaseValue = ToLower(ContentTypeHeaderValue);
+ if (LowerCaseValue.starts_with("multipart/byteranges"))
+ {
+ size_t BoundaryPos = LowerCaseValue.find("boundary=");
+ if (BoundaryPos != std::string::npos)
+ {
+ // Yes, we do a substring of the non-lowercase value string as we want the exact boundary string
+ std::string_view BoundaryName = std::string_view(ContentTypeHeaderValue).substr(BoundaryPos + 9);
+ size_t BoundaryEnd = std::string::npos;
+ while (!BoundaryName.empty() && BoundaryName[0] == ' ')
+ {
+ BoundaryName = BoundaryName.substr(1);
+ }
+ if (!BoundaryName.empty())
+ {
+ if (BoundaryName.size() > 2 && BoundaryName.front() == '"' && BoundaryName.back() == '"')
+ {
+ BoundaryEnd = BoundaryName.find('"', 1);
+ if (BoundaryEnd != std::string::npos)
+ {
+ BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(1, BoundaryEnd - 1)));
+ return true;
+ }
+ }
+ else
+ {
+ BoundaryEnd = BoundaryName.find_first_of(" \r\n");
+ BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(0, BoundaryEnd)));
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ void MultipartBoundaryParser::ParseInput(std::string_view data)
+ {
+ const char* InputPtr = data.data();
+ size_t InputLength = data.length();
+ size_t ScanPos = 0;
+ while (ScanPos < InputLength)
+ {
+ const char ScanChar = InputPtr[ScanPos];
+ if (BoundaryBeginMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ if (PayloadOffset + ScanPos < (BoundaryBeginMatcher.GetMatchEndOffset() + BoundaryEndMatcher.GetMatchString().length()))
+ {
+ BoundaryEndMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+ if (BoundaryEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ BoundaryBeginMatcher.Reset();
+ HeaderEndMatcher.Reset();
+ BoundaryEndMatcher.Reset();
+ BoundaryHeader.Reset();
+ break;
+ }
+ }
+
+ BoundaryHeader.Append(ScanChar);
+
+ HeaderEndMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+
+ if (HeaderEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ const uint64_t HeaderStartOffset = BoundaryBeginMatcher.GetMatchEndOffset();
+ const uint64_t HeaderEndOffset = HeaderEndMatcher.GetMatchStartOffset();
+ const uint64_t HeaderLength = HeaderEndOffset - HeaderStartOffset;
+ std::string_view HeaderText(BoundaryHeader.ToView().substr(0, HeaderLength));
+
+ uint64_t OffsetInPayload = PayloadOffset + ScanPos + 1;
+
+ uint64_t RangeOffset = 0;
+ uint64_t RangeLength = 0;
+ HttpContentType ContentType = HttpContentType::kBinary;
+
+ ForEachStrTok(HeaderText, "\r\n", [&](std::string_view Line) {
+ const std::pair<std::string_view, std::string_view> KeyAndValue = GetHeaderKeyAndValue(Line);
+ const std::string_view Key = KeyAndValue.first;
+ const std::string_view Value = KeyAndValue.second;
+ if (Key == "Content-Range")
+ {
+ std::pair<uint64_t, uint64_t> ContentRange = ParseContentRange(Value);
+ if (ContentRange.second != 0)
+ {
+ RangeOffset = ContentRange.first;
+ RangeLength = ContentRange.second;
+ }
+ }
+ else if (Key == "Content-Type")
+ {
+ ContentType = ParseContentType(Value);
+ }
+
+ return true;
+ });
+
+ if (RangeLength > 0)
+ {
+ Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = OffsetInPayload,
+ .RangeOffset = RangeOffset,
+ .RangeLength = RangeLength,
+ .ContentType = ContentType});
+ }
+
+ BoundaryBeginMatcher.Reset();
+ HeaderEndMatcher.Reset();
+ BoundaryEndMatcher.Reset();
+ BoundaryHeader.Reset();
+ }
+ }
+ else
+ {
+ BoundaryBeginMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+ }
+ ScanPos++;
+ }
+ PayloadOffset += InputLength;
+ }
+
+ std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString)
+ {
+ size_t DelimiterPos = HeaderString.find(':');
+ if (DelimiterPos != std::string::npos)
+ {
+ std::string_view Key = HeaderString.substr(0, DelimiterPos);
+ constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n");
+ Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters);
+ Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters);
+
+ std::string_view Value = HeaderString.substr(DelimiterPos + 1);
+ Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters);
+ Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters);
+ return std::make_pair(Key, Value);
+ }
+ return std::make_pair(HeaderString, std::string_view{});
+ }
+
+ std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value)
+ {
+ if (Value.starts_with("bytes "))
+ {
+ size_t RangeSplitPos = Value.find('-', 6);
+ if (RangeSplitPos != std::string::npos)
+ {
+ size_t RangeEndLength = Value.find('/', RangeSplitPos + 1);
+ if (RangeEndLength == std::string::npos)
+ {
+ RangeEndLength = Value.length() - (RangeSplitPos + 1);
+ }
+ else
+ {
+ RangeEndLength = RangeEndLength - (RangeSplitPos + 1);
+ }
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(Value.substr(6, RangeSplitPos - 6));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(Value.substr(RangeSplitPos + 1, RangeEndLength));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ uint64_t RangeOffset = RequestedRangeStart.value();
+ uint64_t RangeLength = RequestedRangeEnd.value() - RangeOffset + 1;
+ return std::make_pair(RangeOffset, RangeLength);
+ }
+ }
+ }
+ return {0, 0};
+ }
+
} // namespace detail
} // namespace zen
@@ -423,6 +597,8 @@ namespace testutil {
} // namespace testutil
+TEST_SUITE_BEGIN("http.httpclientcommon");
+
TEST_CASE("BufferedReadFileStream")
{
ScopedTemporaryDirectory TmpDir;
@@ -470,5 +646,150 @@ TEST_CASE("CompositeBufferReadStream")
CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data));
}
+TEST_CASE("MultipartBoundaryParser")
+{
+ uint64_t Range1Offset = 2638;
+ uint64_t Range1Length = (5111437 - Range1Offset) + 1;
+
+ uint64_t Range2Offset = 5118199;
+ uint64_t Range2Length = (9147741 - Range2Offset) + 1;
+
+ std::string_view ContentTypeHeaderValue1 = "multipart/byteranges; boundary=00000000000000019229";
+ std::string_view ContentTypeHeaderValue2 = "multipart/byteranges; boundary=\"00000000000000019229\"";
+
+ {
+ std::string_view Example1 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/44369878\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample1;
+ ParserExample1.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 7;
+ for (size_t Offset = 0; Offset < Example1.length(); Offset += InputWindow)
+ {
+ ParserExample1.ParseInput(Example1.substr(Offset, Min(Example1.length() - Offset, InputWindow)));
+ }
+
+ CHECK(ParserExample1.Boundaries.size() == 2);
+
+ CHECK(ParserExample1.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample1.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample1.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample1.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example2 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample2;
+ ParserExample2.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 3;
+ for (size_t Offset = 0; Offset < Example2.length(); Offset += InputWindow)
+ {
+ std::string_view Window = Example2.substr(Offset, Min(Example2.length() - Offset, InputWindow));
+ ParserExample2.ParseInput(Window);
+ }
+
+ CHECK(ParserExample2.Boundaries.size() == 2);
+
+ CHECK(ParserExample2.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample2.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample2.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample2.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example3 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita";
+
+ detail::MultipartBoundaryParser ParserExample3;
+ ParserExample3.Init(ContentTypeHeaderValue2);
+
+ const size_t InputWindow = 31;
+ for (size_t Offset = 0; Offset < Example3.length(); Offset += InputWindow)
+ {
+ ParserExample3.ParseInput(Example3.substr(Offset, Min(Example3.length() - Offset, InputWindow)));
+ }
+
+ CHECK(ParserExample3.Boundaries.size() == 2);
+
+ CHECK(ParserExample3.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample3.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample3.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample3.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example4 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "Not: really\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--000000000bait0019229\r\n"
+ "\r\n--00\r\n--000000000bait001922\r\n"
+ "\r\n\r\n\r\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "Content-Type: application/x-ue-comp\r\n"
+ "ditaditadita"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n---\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample4;
+ ParserExample4.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 3;
+ for (size_t Offset = 0; Offset < Example4.length(); Offset += InputWindow)
+ {
+ std::string_view Window = Example4.substr(Offset, Min(Example4.length() - Offset, InputWindow));
+ ParserExample4.ParseInput(Window);
+ }
+
+ CHECK(ParserExample4.Boundaries.size() == 2);
+
+ CHECK(ParserExample4.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample4.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample4.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample4.Boundaries[1].RangeLength == Range2Length);
+ }
+}
+
+TEST_SUITE_END();
+
} // namespace zen
#endif
diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h
index 1d0b7f9ea..5ed946541 100644
--- a/src/zenhttp/clients/httpclientcommon.h
+++ b/src/zenhttp/clients/httpclientcommon.h
@@ -3,6 +3,7 @@
#pragma once
#include <zencore/compositebuffer.h>
+#include <zencore/string.h>
#include <zencore/trace.h>
#include <zenhttp/httpclient.h>
@@ -87,7 +88,7 @@ namespace detail {
std::error_code Write(std::string_view DataString);
IoBuffer DetachToIoBuffer();
IoBuffer BorrowIoBuffer();
- inline uint64_t GetSize() const { return m_WriteOffset; }
+ inline uint64_t GetSize() const { return m_WriteOffset + m_CacheBufferOffset; }
void ResetWritePos(uint64_t WriteOffset);
private:
@@ -143,6 +144,118 @@ namespace detail {
uint64_t m_BytesLeftInSegment;
};
+ class IncrementalStringMatcher
+ {
+ public:
+ enum class EMatchState
+ {
+ None,
+ Partial,
+ Complete
+ };
+
+ EMatchState MatchState = EMatchState::None;
+
+ IncrementalStringMatcher() {}
+
+ IncrementalStringMatcher(std::string&& InMatchString) : MatchString(std::move(InMatchString))
+ {
+ RawMatchString = MatchString.data();
+ }
+
+ void Init(std::string&& InMatchString)
+ {
+ MatchString = std::move(InMatchString);
+ RawMatchString = MatchString.data();
+ }
+
+ inline void Reset()
+ {
+ MatchLength = 0;
+ MatchStartOffset = 0;
+ MatchState = EMatchState::None;
+ }
+
+ inline uint64_t GetMatchEndOffset() const
+ {
+ if (MatchState == EMatchState::Complete)
+ {
+ return MatchStartOffset + MatchString.length();
+ }
+ return 0;
+ }
+
+ inline uint64_t GetMatchStartOffset() const
+ {
+ ZEN_ASSERT(MatchState == EMatchState::Complete);
+ return MatchStartOffset;
+ }
+
+ void Match(uint64_t Offset, char C)
+ {
+ ZEN_ASSERT_SLOW(RawMatchString != nullptr);
+
+ if (MatchState == EMatchState::Complete)
+ {
+ Reset();
+ }
+ if (C == RawMatchString[MatchLength])
+ {
+ if (MatchLength == 0)
+ {
+ MatchStartOffset = Offset;
+ }
+ MatchLength++;
+ if (MatchLength == MatchString.length())
+ {
+ MatchState = EMatchState::Complete;
+ }
+ else
+ {
+ MatchState = EMatchState::Partial;
+ }
+ }
+ else if (MatchLength != 0)
+ {
+ Reset();
+ Match(Offset, C);
+ }
+ else
+ {
+ Reset();
+ }
+ }
+ inline const std::string& GetMatchString() const { return MatchString; }
+
+ private:
+ std::string MatchString;
+ const char* RawMatchString = nullptr;
+ uint64_t MatchLength = 0;
+
+ uint64_t MatchStartOffset = 0;
+ };
+
+ class MultipartBoundaryParser
+ {
+ public:
+ std::vector<HttpClient::Response::MultipartBoundary> Boundaries;
+
+ MultipartBoundaryParser();
+ bool Init(const std::string_view ContentTypeHeaderValue);
+ void ParseInput(std::string_view data);
+
+ private:
+ IncrementalStringMatcher BoundaryBeginMatcher;
+ IncrementalStringMatcher BoundaryEndMatcher;
+ IncrementalStringMatcher HeaderEndMatcher;
+
+ ExtendableStringBuilder<64> BoundaryHeader;
+ uint64_t PayloadOffset = 0;
+ };
+
+ std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString);
+ std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value);
+
} // namespace detail
} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp
index 5d92b3b6b..14e40b02a 100644
--- a/src/zenhttp/clients/httpclientcpr.cpp
+++ b/src/zenhttp/clients/httpclientcpr.cpp
@@ -12,6 +12,7 @@
#include <zencore/session.h>
#include <zencore/stream.h>
#include <zenhttp/packageformat.h>
+#include <algorithm>
namespace zen {
@@ -23,6 +24,21 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti
static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
+bool
+HttpClient::ErrorContext::IsConnectionError() const
+{
+ switch (static_cast<cpr::ErrorCode>(ErrorCode))
+ {
+ case cpr::ErrorCode::CONNECTION_FAILURE:
+ case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
+ case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
+ return true;
+ default:
+ return false;
+ }
+}
+
// If we want to support different HTTP client implementations then we'll need to make this more abstract
HttpClientError::ResponseClass
@@ -149,6 +165,18 @@ CprHttpClient::CprHttpClient(std::string_view BaseUri,
{
}
+bool
+CprHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const
+{
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ // Quiet
+ return false;
+ }
+ const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes;
+ return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end();
+}
+
CprHttpClient::~CprHttpClient()
{
ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient");
@@ -162,10 +190,11 @@ CprHttpClient::~CprHttpClient()
}
HttpClient::Response
-CprHttpClient::ResponseWithPayload(std::string_view SessionId,
- cpr::Response&& HttpResponse,
- const HttpResponseCode WorkResponseCode,
- IoBuffer&& Payload)
+CprHttpClient::ResponseWithPayload(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
{
// This ends up doing a memcpy, would be good to get rid of it by streaming results
// into buffer directly
@@ -174,30 +203,37 @@ CprHttpClient::ResponseWithPayload(std::string_view SessionId,
if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end())
{
const HttpContentType ContentType = ParseContentType(It->second);
-
ResponseBuffer.SetContentType(ContentType);
}
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
-
- if (!Quiet)
+ if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
{
- if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
+ if (ShouldLogErrorCode(WorkResponseCode))
{
ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse);
}
}
+ std::sort(BoundaryPositions.begin(),
+ BoundaryPositions.end(),
+ [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) {
+ return Lhs.RangeOffset < Rhs.RangeOffset;
+ });
+
return HttpClient::Response{.StatusCode = WorkResponseCode,
.ResponsePayload = std::move(ResponseBuffer),
.Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
.UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
.DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
- .ElapsedSeconds = HttpResponse.elapsed};
+ .ElapsedSeconds = HttpResponse.elapsed,
+ .Ranges = std::move(BoundaryPositions)};
}
HttpClient::Response
-CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload)
+CprHttpClient::CommonResponse(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
{
const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code);
if (HttpResponse.error)
@@ -235,7 +271,7 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe
}
else
{
- return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload));
+ return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions));
}
}
@@ -346,8 +382,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId,
}
Sleep(100 * (Attempt + 1));
Attempt++;
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
- if (!Quiet)
+ if (ShouldLogErrorCode(HttpResponseCode(Result.status_code)))
{
ZEN_INFO("{} Attempt {}/{}",
CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
@@ -385,8 +420,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId,
}
Sleep(100 * (Attempt + 1));
Attempt++;
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
- if (!Quiet)
+ if (ShouldLogErrorCode(HttpResponseCode(Result.status_code)))
{
ZEN_INFO("{} Attempt {}/{}",
CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
@@ -621,7 +655,7 @@ CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const Ke
ResponseBuffer.SetContentType(ContentType);
}
- return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer};
+ return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = std::move(ResponseBuffer)};
}
//////////////////////////////////////////////////////////////////////////
@@ -896,236 +930,287 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF
std::string PayloadString;
std::unique_ptr<detail::TempPayloadFile> PayloadFile;
- cpr::Response Response = DoWithRetry(
- m_SessionId,
- [&]() {
- auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> {
- size_t DelimiterPos = header.find(':');
- if (DelimiterPos != std::string::npos)
- {
- std::string Key = header.substr(0, DelimiterPos);
- constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n");
- Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters);
- Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters);
-
- std::string Value = header.substr(DelimiterPos + 1);
- Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters);
- Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters);
-
- return std::make_pair(Key, Value);
- }
- return std::make_pair(header, "");
- };
-
- auto DownloadCallback = [&](std::string data, intptr_t) {
- if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
- {
- return false;
- }
- if (PayloadFile)
- {
- ZEN_ASSERT(PayloadString.empty());
- std::error_code Ec = PayloadFile->Write(data);
- if (Ec)
- {
- ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
- TempFolderPath.string(),
- Ec.message());
- return false;
- }
- }
- else
- {
- PayloadString.append(data);
- }
- return true;
- };
-
- uint64_t RequestedContentLength = (uint64_t)-1;
- if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
- {
- if (RangeIt->second.starts_with("bytes"))
- {
- size_t RangeStartPos = RangeIt->second.find('=', 5);
- if (RangeStartPos != std::string::npos)
- {
- RangeStartPos++;
- size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos);
- if (RangeSplitPos != std::string::npos)
- {
- std::optional<size_t> RequestedRangeStart =
- ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos));
- std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1));
- if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
- {
- RequestedContentLength = RequestedRangeEnd.value() - 1;
- }
- }
- }
- }
- }
-
- cpr::Response Response;
- {
- std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
- auto HeaderCallback = [&](std::string header, intptr_t) {
- std::pair<std::string, std::string> Header = GetHeader(header);
- if (Header.first == "Content-Length"sv)
- {
- std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
- if (ContentLength.has_value())
- {
- if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize)
- {
- PayloadFile = std::make_unique<detail::TempPayloadFile>();
- std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
- if (Ec)
- {
- ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
- TempFolderPath.string(),
- Ec.message());
- PayloadFile.reset();
- }
- }
- else
- {
- PayloadString.reserve(ContentLength.value());
- }
- }
- }
- if (!Header.first.empty())
- {
- ReceivedHeaders.emplace_back(std::move(Header));
- }
- return 1;
- };
-
- Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
- Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
- for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
- {
- Response.header.insert_or_assign(H.first, H.second);
- }
- }
- if (m_ConnectionSettings.AllowResume)
- {
- auto SupportsRanges = [](const cpr::Response& Response) -> bool {
- if (Response.header.find("Content-Range") != Response.header.end())
- {
- return true;
- }
- if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end())
- {
- return It->second == "bytes"sv;
- }
- return false;
- };
-
- auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool {
- if (ShouldRetry(Response))
- {
- return SupportsRanges(Response);
- }
- return false;
- };
-
- if (ShouldResume(Response))
- {
- auto It = Response.header.find("Content-Length");
- if (It != Response.header.end())
- {
- uint64_t ContentLength = RequestedContentLength;
- if (ContentLength == uint64_t(-1))
- {
- if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value())
- {
- ContentLength = ParsedContentLength.value();
- }
- }
-
- std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
-
- auto HeaderCallback = [&](std::string header, intptr_t) {
- std::pair<std::string, std::string> Header = GetHeader(header);
- if (!Header.first.empty())
- {
- ReceivedHeaders.emplace_back(std::move(Header));
- }
-
- if (Header.first == "Content-Range"sv)
- {
- if (Header.second.starts_with("bytes "sv))
- {
- size_t RangeStartEnd = Header.second.find('-', 6);
- if (RangeStartEnd != std::string::npos)
- {
- const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6));
- if (Start)
- {
- uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
- if (Start.value() == DownloadedSize)
- {
- return 1;
- }
- else if (Start.value() > DownloadedSize)
- {
- return 0;
- }
- if (PayloadFile)
- {
- PayloadFile->ResetWritePos(Start.value());
- }
- else
- {
- PayloadString = PayloadString.substr(0, Start.value());
- }
- return 1;
- }
- }
- }
- return 0;
- }
- return 1;
- };
-
- KeyValueMap HeadersWithRange(AdditionalHeader);
- do
- {
- uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
-
- std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
- if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
- {
- if (RangeIt->second == Range)
- {
- // If we didn't make any progress, abort
- break;
- }
- }
- HeadersWithRange.Entries.insert_or_assign("Range", Range);
-
- Session Sess =
- AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
- Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
- for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
- {
- Response.header.insert_or_assign(H.first, H.second);
- }
- ReceivedHeaders.clear();
- } while (ShouldResume(Response));
- }
- }
- }
-
- if (!PayloadString.empty())
- {
- Response.text = std::move(PayloadString);
- }
- return Response;
- },
- PayloadFile);
-
- return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{});
+
+ HttpContentType ContentType = HttpContentType::kUnknownContentType;
+ detail::MultipartBoundaryParser BoundaryParser;
+ bool IsMultiRangeResponse = false;
+
+ cpr::Response Response = DoWithRetry(
+ m_SessionId,
+ [&]() {
+ // Reset state from any previous attempt
+ PayloadString.clear();
+ PayloadFile.reset();
+ BoundaryParser.Boundaries.clear();
+ ContentType = HttpContentType::kUnknownContentType;
+ IsMultiRangeResponse = false;
+
+ auto DownloadCallback = [&](std::string data, intptr_t) {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return false;
+ }
+
+ if (IsMultiRangeResponse)
+ {
+ BoundaryParser.ParseInput(data);
+ }
+
+ if (PayloadFile)
+ {
+ ZEN_ASSERT(PayloadString.empty());
+ std::error_code Ec = PayloadFile->Write(data);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ return false;
+ }
+ }
+ else
+ {
+ PayloadString.append(data);
+ }
+ return true;
+ };
+
+ uint64_t RequestedContentLength = (uint64_t)-1;
+ if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
+ {
+ if (RangeIt->second.starts_with("bytes"))
+ {
+ std::string_view RangeValue(RangeIt->second);
+ size_t RangeStartPos = RangeValue.find('=', 5);
+ if (RangeStartPos != std::string::npos)
+ {
+ RangeStartPos++;
+ while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ')
+ {
+ RangeStartPos++;
+ }
+ RequestedContentLength = 0;
+
+ while (RangeStartPos < RangeValue.length())
+ {
+ size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos);
+ if (RangeEnd == std::string::npos)
+ {
+ RangeEnd = RangeValue.length();
+ }
+
+ std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos);
+ size_t RangeSplitPos = RangeString.find('-');
+ if (RangeSplitPos != std::string::npos)
+ {
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1;
+ }
+ }
+ RangeStartPos = RangeEnd;
+ while (RangeStartPos != RangeValue.length() &&
+ (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' '))
+ {
+ RangeStartPos++;
+ }
+ }
+ }
+ }
+ }
+
+ cpr::Response Response;
+ {
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header);
+ if (Header.first == "Content-Length"sv)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
+ if (ContentLength.has_value())
+ {
+ if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize)
+ {
+ PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ PayloadFile.reset();
+ }
+ }
+ else
+ {
+ PayloadString.reserve(ContentLength.value());
+ }
+ }
+ }
+ else if (Header.first == "Content-Type")
+ {
+ IsMultiRangeResponse = BoundaryParser.Init(Header.second);
+ if (!IsMultiRangeResponse)
+ {
+ ContentType = ParseContentType(Header.second);
+ }
+ }
+ else if (Header.first == "Content-Range")
+ {
+ if (!IsMultiRangeResponse)
+ {
+ std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Header.second);
+ if (Range.second != 0)
+ {
+ BoundaryParser.Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0,
+ .RangeOffset = Range.first,
+ .RangeLength = Range.second,
+ .ContentType = ContentType});
+ }
+ }
+ }
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+ return 1;
+ };
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ }
+ if (m_ConnectionSettings.AllowResume)
+ {
+ auto SupportsRanges = [](const cpr::Response& Response) -> bool {
+ if (Response.header.find("Content-Range") != Response.header.end())
+ {
+ return true;
+ }
+ if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end())
+ {
+ return It->second == "bytes"sv;
+ }
+ return false;
+ };
+
+ auto ShouldResume = [&SupportsRanges, &IsMultiRangeResponse](const cpr::Response& Response) -> bool {
+ if (IsMultiRangeResponse)
+ {
+ return false;
+ }
+ if (ShouldRetry(Response))
+ {
+ return SupportsRanges(Response);
+ }
+ return false;
+ };
+
+ if (ShouldResume(Response))
+ {
+ auto It = Response.header.find("Content-Length");
+ if (It != Response.header.end())
+ {
+ uint64_t ContentLength = RequestedContentLength;
+ if (ContentLength == uint64_t(-1))
+ {
+ if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value())
+ {
+ ContentLength = ParsedContentLength.value();
+ }
+ }
+
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header);
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+
+ if (Header.first == "Content-Range"sv)
+ {
+ if (Header.second.starts_with("bytes "sv))
+ {
+ size_t RangeStartEnd = Header.second.find('-', 6);
+ if (RangeStartEnd != std::string::npos)
+ {
+ const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6));
+ if (Start)
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+ if (Start.value() == DownloadedSize)
+ {
+ return 1;
+ }
+ else if (Start.value() > DownloadedSize)
+ {
+ return 0;
+ }
+ if (PayloadFile)
+ {
+ PayloadFile->ResetWritePos(Start.value());
+ }
+ else
+ {
+ PayloadString = PayloadString.substr(0, Start.value());
+ }
+ return 1;
+ }
+ }
+ }
+ return 0;
+ }
+ return 1;
+ };
+
+ KeyValueMap HeadersWithRange(AdditionalHeader);
+ do
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+
+ std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
+ if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
+ {
+ if (RangeIt->second == Range)
+ {
+ // If we didn't make any progress, abort
+ break;
+ }
+ }
+ HeadersWithRange.Entries.insert_or_assign("Range", Range);
+
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ ReceivedHeaders.clear();
+ } while (ShouldResume(Response));
+ }
+ }
+ }
+
+ if (!PayloadString.empty())
+ {
+ Response.text = std::move(PayloadString);
+ }
+ return Response;
+ },
+ PayloadFile);
+
+ return CommonResponse(m_SessionId,
+ std::move(Response),
+ PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{},
+ std::move(BoundaryParser.Boundaries));
}
} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h
index 40af53b5d..752d91add 100644
--- a/src/zenhttp/clients/httpclientcpr.h
+++ b/src/zenhttp/clients/httpclientcpr.h
@@ -155,14 +155,19 @@ private:
std::function<cpr::Response()>&& Func,
std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; });
+ bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const;
bool ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
- HttpClient::Response CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload);
+ HttpClient::Response CommonResponse(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {});
- HttpClient::Response ResponseWithPayload(std::string_view SessionId,
- cpr::Response&& HttpResponse,
- const HttpResponseCode WorkResponseCode,
- IoBuffer&& Payload);
+ HttpClient::Response ResponseWithPayload(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions);
};
} // namespace zen
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
new file mode 100644
index 000000000..9497dadb8
--- /dev/null
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -0,0 +1,566 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpwsclient.h>
+
+#include "../servers/wsframecodec.h"
+
+#include <zencore/base64.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <deque>
+#include <random>
+#include <thread>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpWsClient::Impl
+{
+ Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_OwnedIoContext(std::make_unique<asio::io_context>())
+ , m_IoContext(*m_OwnedIoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_IoContext(IoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ ~Impl()
+ {
+ // Release work guard so io_context::run() can return
+ m_WorkGuard.reset();
+
+ // Close the socket to cancel pending async ops
+ if (m_Socket)
+ {
+ asio::error_code Ec;
+ m_Socket->close(Ec);
+ }
+
+ if (m_IoThread.joinable())
+ {
+ m_IoThread.join();
+ }
+ }
+
+ void ParseUrl(std::string_view Url)
+ {
+ // Expected format: ws://host:port/path
+ if (Url.substr(0, 5) == "ws://")
+ {
+ Url.remove_prefix(5);
+ }
+
+ auto SlashPos = Url.find('/');
+ std::string_view HostPort;
+ if (SlashPos != std::string_view::npos)
+ {
+ HostPort = Url.substr(0, SlashPos);
+ m_Path = std::string(Url.substr(SlashPos));
+ }
+ else
+ {
+ HostPort = Url;
+ m_Path = "/";
+ }
+
+ auto ColonPos = HostPort.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ m_Host = std::string(HostPort.substr(0, ColonPos));
+ m_Port = std::string(HostPort.substr(ColonPos + 1));
+ }
+ else
+ {
+ m_Host = std::string(HostPort);
+ m_Port = "80";
+ }
+ }
+
+ void Connect()
+ {
+ if (m_OwnedIoContext)
+ {
+ m_WorkGuard = std::make_unique<asio::io_context::work>(m_IoContext);
+ m_IoThread = std::thread([this] { m_IoContext.run(); });
+ }
+
+ asio::post(m_IoContext, [this] { DoResolve(); });
+ }
+
+ void DoResolve()
+ {
+ m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext);
+
+ m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) {
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "resolve failed");
+ return;
+ }
+
+ DoConnect(Results);
+ });
+ }
+
+ void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints)
+ {
+ m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoContext);
+
+ // Start connect timeout timer
+ m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
+ m_Timer->async_wait([this](const asio::error_code& Ec) {
+ if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port);
+ if (m_Socket)
+ {
+ asio::error_code CloseEc;
+ m_Socket->close(CloseEc);
+ }
+ }
+ });
+
+ asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "connect failed");
+ return;
+ }
+
+ DoHandshake();
+ });
+ }
+
+ void DoHandshake()
+ {
+ // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded)
+ uint8_t KeyBytes[16];
+ {
+ static thread_local std::mt19937 s_Rng(std::random_device{}());
+ for (int i = 0; i < 4; ++i)
+ {
+ uint32_t Val = s_Rng();
+ std::memcpy(KeyBytes + i * 4, &Val, 4);
+ }
+ }
+
+ char KeyBase64[Base64::GetEncodedDataSize(16) + 1];
+ uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64);
+ KeyBase64[KeyLen] = '\0';
+ m_WebSocketKey = std::string(KeyBase64, KeyLen);
+
+ // Build the HTTP upgrade request
+ ExtendableStringBuilder<512> Request;
+ Request << "GET " << m_Path << " HTTP/1.1\r\n"
+ << "Host: " << m_Host << ":" << m_Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n"
+ << "Sec-WebSocket-Version: 13\r\n";
+
+ // Add Authorization header if access token provider is set
+ if (m_Settings.AccessTokenProvider)
+ {
+ HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)();
+ if (Token.IsValid())
+ {
+ Request << "Authorization: Bearer " << Token.Value << "\r\n";
+ }
+ }
+
+ Request << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+
+ m_HandshakeBuffer = std::make_shared<std::string>(ReqStr);
+
+ asio::async_write(*m_Socket,
+ asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()),
+ [this](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake write failed");
+ return;
+ }
+
+ DoReadHandshakeResponse();
+ });
+ }
+
+ void DoReadHandshakeResponse()
+ {
+ asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) {
+ m_Timer->cancel();
+
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake read failed");
+ return;
+ }
+
+ // Parse the response
+ const auto& Data = m_ReadBuffer.data();
+ std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
+
+ // Consume the headers from the read buffer (any extra data stays for frame parsing)
+ auto HeaderEnd = Response.find("\r\n\r\n");
+ if (HeaderEnd != std::string::npos)
+ {
+ m_ReadBuffer.consume(HeaderEnd + 4);
+ }
+
+ // Validate 101 response
+ if (Response.find("101") == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
+ m_Handler.OnWsClose(1006, "handshake rejected");
+ return;
+ }
+
+ // Validate Sec-WebSocket-Accept
+ std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
+ if (Response.find(ExpectedAccept) == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
+ m_Handler.OnWsClose(1006, "invalid accept key");
+ return;
+ }
+
+ m_IsOpen.store(true);
+ m_Handler.OnWsOpen();
+ EnqueueRead();
+ });
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Read loop
+ //
+
+ void EnqueueRead()
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) {
+ OnDataReceived(Ec);
+ });
+ }
+
+ void OnDataReceived(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message());
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "connection lost");
+ }
+ return;
+ }
+
+ ProcessReceivedData();
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ EnqueueRead();
+ }
+ }
+
+ void ProcessReceivedData()
+ {
+ while (m_ReadBuffer.size() > 0)
+ {
+ const auto& InputBuffer = m_ReadBuffer.data();
+ const auto* RawData = static_cast<const uint8_t*>(InputBuffer.data());
+ const auto Size = InputBuffer.size();
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size);
+ if (!Frame.IsValid)
+ {
+ break;
+ }
+
+ m_ReadBuffer.consume(Frame.BytesConsumed);
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWsMessage(Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with masked pong
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason =
+ std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo masked close frame if we haven't sent one yet
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWsClose(Code, Reason);
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Write queue
+ //
+
+ void EnqueueWrite(std::vector<uint8_t> Frame)
+ {
+ bool ShouldFlush = false;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_WriteQueue.push_back(std::move(Frame));
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ });
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+ }
+
+ void FlushWriteQueue()
+ {
+ std::vector<uint8_t> Frame;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+ Frame = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ });
+
+ if (Frame.empty())
+ {
+ return;
+ }
+
+ auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
+
+ asio::async_write(*m_Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); });
+ }
+
+ void OnWriteComplete(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message());
+ }
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_IsWriting = false;
+ m_WriteQueue.clear();
+ });
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "write error");
+ }
+ return;
+ }
+
+ FlushWriteQueue();
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Public operations
+ //
+
+ void SendText(std::string_view Text)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void SendBinary(std::span<const uint8_t> Data)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void DoClose(uint16_t Code, std::string_view Reason)
+ {
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ IWsClientHandler& m_Handler;
+ HttpWsClientSettings m_Settings;
+ LoggerRef m_Log;
+
+ std::string m_Host;
+ std::string m_Port;
+ std::string m_Path;
+
+ // io_context: owned (standalone) or external (shared)
+ std::unique_ptr<asio::io_context> m_OwnedIoContext;
+ asio::io_context& m_IoContext;
+ std::unique_ptr<asio::io_context::work> m_WorkGuard;
+ std::thread m_IoThread;
+
+ // Connection state
+ std::unique_ptr<asio::ip::tcp::resolver> m_Resolver;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<asio::steady_timer> m_Timer;
+ asio::streambuf m_ReadBuffer;
+ std::string m_WebSocketKey;
+ std::shared_ptr<std::string> m_HandshakeBuffer;
+
+ // Write queue
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ bool m_IsWriting = false;
+
+ std::atomic<bool> m_IsOpen{false};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, Settings))
+{
+}
+
+HttpWsClient::HttpWsClient(std::string_view Url,
+ IWsClientHandler& Handler,
+ asio::io_context& IoContext,
+ const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, IoContext, Settings))
+{
+}
+
+HttpWsClient::~HttpWsClient() = default;
+
+void
+HttpWsClient::Connect()
+{
+ m_Impl->Connect();
+}
+
+void
+HttpWsClient::SendText(std::string_view Text)
+{
+ m_Impl->SendText(Text);
+}
+
+void
+HttpWsClient::SendBinary(std::span<const uint8_t> Data)
+{
+ m_Impl->SendBinary(Data);
+}
+
+void
+HttpWsClient::Close(uint16_t Code, std::string_view Reason)
+{
+ m_Impl->DoClose(Code, Reason);
+}
+
+bool
+HttpWsClient::IsOpen() const
+{
+ return m_Impl->m_IsOpen.load(std::memory_order_relaxed);
+}
+
+} // namespace zen
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index 43e9fb468..281d512cf 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -21,9 +21,17 @@
#include "clients/httpclientcommon.h"
+#include <numeric>
+
#if ZEN_WITH_TESTS
+# include <zencore/scopeguard.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
+# include <zenhttp/security/passwordsecurityfilter.h>
+# include "servers/httpasio.h"
+# include "servers/httpsys.h"
+
+# include <thread>
#endif // ZEN_WITH_TESTS
namespace zen {
@@ -96,6 +104,44 @@ HttpClientBase::GetAccessToken()
//////////////////////////////////////////////////////////////////////////
+std::vector<std::pair<uint64_t, uint64_t>>
+HttpClient::Response::GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const
+{
+ if (Ranges.empty())
+ {
+ return {};
+ }
+
+ std::vector<std::pair<uint64_t, uint64_t>> Result;
+ Result.reserve(OffsetAndLengthPairs.size());
+
+ auto BoundaryIt = Ranges.begin();
+ auto OffsetAndLengthPairIt = OffsetAndLengthPairs.begin();
+ while (OffsetAndLengthPairIt != OffsetAndLengthPairs.end())
+ {
+ uint64_t Offset = OffsetAndLengthPairIt->first;
+ uint64_t Length = OffsetAndLengthPairIt->second;
+ while (Offset >= BoundaryIt->RangeOffset + BoundaryIt->RangeLength)
+ {
+ BoundaryIt++;
+ if (BoundaryIt == Ranges.end())
+ {
+ throw std::runtime_error("HttpClient::Response can not fulfill requested range");
+ }
+ }
+ if (Offset + Length > BoundaryIt->RangeOffset + BoundaryIt->RangeLength || Offset < BoundaryIt->RangeOffset)
+ {
+ throw std::runtime_error("HttpClient::Response can not fulfill requested range");
+ }
+ uint64_t OffsetIntoRange = Offset - BoundaryIt->RangeOffset;
+ uint64_t RangePayloadOffset = BoundaryIt->OffsetInPayload + OffsetIntoRange;
+ Result.emplace_back(std::make_pair(RangePayloadOffset, Length));
+
+ OffsetAndLengthPairIt++;
+ }
+ return Result;
+}
+
CbObject
HttpClient::Response::AsObject() const
{
@@ -334,10 +380,55 @@ HttpClient::Authenticate()
return m_Inner->Authenticate();
}
+LatencyTestResult
+MeasureLatency(HttpClient& Client, std::string_view Url)
+{
+ std::vector<double> MeasurementTimes;
+ std::string ErrorMessage;
+
+ for (uint32_t AttemptCount = 0; AttemptCount < 20 && MeasurementTimes.size() < 5; AttemptCount++)
+ {
+ HttpClient::Response MeasureResponse = Client.Get(Url);
+ if (MeasureResponse.IsSuccess())
+ {
+ MeasurementTimes.push_back(MeasureResponse.ElapsedSeconds);
+ Sleep(5);
+ }
+ else
+ {
+ ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url));
+
+ // Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable.
+ // Bail out immediately — retrying will just burn the connect timeout each time.
+ if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError())
+ {
+ break;
+ }
+ }
+ }
+
+ if (MeasurementTimes.empty())
+ {
+ return {.Success = false, .FailureReason = ErrorMessage};
+ }
+
+ if (MeasurementTimes.size() > 2)
+ {
+ std::sort(MeasurementTimes.begin(), MeasurementTimes.end());
+ MeasurementTimes.pop_back(); // Remove the worst time
+ }
+
+ double AverageLatency = std::accumulate(MeasurementTimes.begin(), MeasurementTimes.end(), 0.0) / MeasurementTimes.size();
+
+ return {.Success = true, .LatencySeconds = AverageLatency};
+}
+
//////////////////////////////////////////////////////////////////////////
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("http.httpclient");
+
TEST_CASE("responseformat")
{
using namespace std::literals;
@@ -388,8 +479,366 @@ TEST_CASE("httpclient")
{
using namespace std::literals;
- SUBCASE("client") {}
+ struct TestHttpService : public HttpService
+ {
+ TestHttpService() = default;
+
+ virtual const char* BaseUri() const override { return "/test/"; }
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override
+ {
+ if (HttpServiceRequest.RelativeUri() == "yo")
+ {
+ if (HttpServiceRequest.IsLocalMachineRequest())
+ {
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family");
+ }
+ else
+ {
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey stranger");
+ }
+ }
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK);
+ }
+ };
+
+ TestHttpService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ SUBCASE("asio")
+ {
+ Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = AsioServer->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ AsioServer->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { AsioServer->Run(false); });
+
+ {
+ auto _ = MakeGuard([&]() {
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ AsioServer->Close();
+ });
+
+ {
+ HttpClient Client(fmt::format("127.0.0.1:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+
+ if (IsIPv6Capable())
+ {
+ HttpClient Client(fmt::format("[::1]:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+
+ {
+ HttpClient Client(fmt::format("localhost:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+# if 0
+ {
+ HttpClient Client(fmt::format("10.24.101.77:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+ Sleep(20000);
+# endif // 0
+ AsioServer->RequestExit();
+ }
+ }
+
+# if ZEN_PLATFORM_WINDOWS
+ SUBCASE("httpsys")
+ {
+ Ref<HttpServer> HttpSysServer = CreateHttpSysServer(HttpSysConfig{.ForceLoopback = false});
+
+ int Port = HttpSysServer->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ HttpSysServer->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { HttpSysServer->Run(false); });
+
+ {
+ auto _ = MakeGuard([&]() {
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ HttpSysServer->Close();
+ });
+
+ if (true)
+ {
+ HttpClient Client(fmt::format("127.0.0.1:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+
+ if (IsIPv6Capable())
+ {
+ HttpClient Client(fmt::format("[::1]:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+
+ {
+ HttpClient Client(fmt::format("localhost:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+# if 0
+ {
+ HttpClient Client(fmt::format("10.24.101.77:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response TestResponse = Client.Get("/test/yo");
+ CHECK(TestResponse.IsSuccess());
+ CHECK_EQ(TestResponse.AsText(), "hey family");
+ }
+ Sleep(20000);
+# endif // 0
+ HttpSysServer->RequestExit();
+ }
+ }
+# endif // ZEN_PLATFORM_WINDOWS
+}
+
+TEST_CASE("httpclient.requestfilter")
+{
+ using namespace std::literals;
+
+ struct TestHttpService : public HttpService
+ {
+ TestHttpService() = default;
+
+ virtual const char* BaseUri() const override { return "/test/"; }
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override
+ {
+ if (HttpServiceRequest.RelativeUri() == "yo")
+ {
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family");
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_filter");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_forbid");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+ }
+ };
+
+ TestHttpService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ class MyFilterImpl : public IHttpRequestFilter
+ {
+ public:
+ virtual Result FilterRequest(HttpServerRequest& Request)
+ {
+ if (Request.RelativeUri() == "should_filter")
+ {
+ Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "no thank you");
+ return Result::ResponseSent;
+ }
+ else if (Request.RelativeUri() == "should_forbid")
+ {
+ return Result::Forbidden;
+ }
+ return Result::Accepted;
+ }
+ };
+
+ MyFilterImpl MyFilter;
+
+ Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{});
+
+ AsioServer->SetHttpRequestFilter(&MyFilter);
+
+ int Port = AsioServer->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ AsioServer->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { AsioServer->Run(false); });
+
+ {
+ auto _ = MakeGuard([&]() {
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ AsioServer->Close();
+ });
+
+ HttpClient Client(fmt::format("localhost:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response YoResponse = Client.Get("/test/yo");
+ CHECK(YoResponse.IsSuccess());
+ CHECK_EQ(YoResponse.AsText(), "hey family");
+
+ HttpClient::Response ShouldFilterResponse = Client.Get("/test/should_filter");
+ CHECK_EQ(ShouldFilterResponse.StatusCode, HttpResponseCode::MethodNotAllowed);
+ CHECK_EQ(ShouldFilterResponse.AsText(), "no thank you");
+
+ HttpClient::Response ShouldForbitResponse = Client.Get("/test/should_forbid");
+ CHECK_EQ(ShouldForbitResponse.StatusCode, HttpResponseCode::Forbidden);
+
+ AsioServer->RequestExit();
+ }
+}
+
+TEST_CASE("httpclient.password")
+{
+ using namespace std::literals;
+
+ struct TestHttpService : public HttpService
+ {
+ TestHttpService() = default;
+
+ virtual const char* BaseUri() const override { return "/test/"; }
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override
+ {
+ if (HttpServiceRequest.RelativeUri() == "yo")
+ {
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family");
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_filter");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_forbid");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+ }
+ };
+
+ TestHttpService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = AsioServer->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ AsioServer->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { AsioServer->Run(false); });
+
+ {
+ auto _ = MakeGuard([&]() {
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ AsioServer->Close();
+ });
+
+ SUBCASE("usernamepassword")
+ {
+ CbObjectWriter Writer;
+ {
+ Writer.BeginObject("basic");
+ {
+ Writer << "username"sv
+ << "me";
+ Writer << "password"sv
+ << "456123789";
+ }
+ Writer.EndObject();
+ Writer << "protect-machine-local-requests" << true;
+ }
+
+ PasswordHttpFilter::Configuration PasswordFilterOptions = PasswordHttpFilter::ReadConfiguration(Writer.Save());
+
+ PasswordHttpFilter MyFilter(PasswordFilterOptions);
+
+ AsioServer->SetHttpRequestFilter(&MyFilter);
+
+ HttpClient Client(fmt::format("localhost:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response ForbiddenResponse = Client.Get("/test/yo");
+ CHECK(!ForbiddenResponse.IsSuccess());
+ CHECK_EQ(ForbiddenResponse.StatusCode, HttpResponseCode::Forbidden);
+
+ HttpClient::Response WithBasicResponse =
+ Client.Get("/test/yo",
+ std::pair<std::string, std::string>("Authorization",
+ fmt::format("Basic {}", PasswordFilterOptions.PasswordConfig.Password)));
+ CHECK(WithBasicResponse.IsSuccess());
+ AsioServer->SetHttpRequestFilter(nullptr);
+ }
+ AsioServer->RequestExit();
+ }
}
+TEST_SUITE_END();
void
httpclient_forcelink()
diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp
new file mode 100644
index 000000000..52bf149a7
--- /dev/null
+++ b/src/zenhttp/httpclient_test.cpp
@@ -0,0 +1,1366 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpclient.h>
+#include <zenhttp/httpserver.h>
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinaryutil.h>
+# include <zencore/compositebuffer.h>
+# include <zencore/iobuffer.h>
+# include <zencore/logging.h>
+# include <zencore/scopeguard.h>
+# include <zencore/session.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+# include "servers/httpasio.h"
+
+# include <atomic>
+# include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+// Test service
+
+class HttpClientTestService : public HttpService
+{
+public:
+ HttpClientTestService()
+ {
+ m_Router.AddMatcher("statuscode", [](std::string_view Str) -> bool {
+ for (char C : Str)
+ {
+ if (C < '0' || C > '9')
+ {
+ return false;
+ }
+ }
+ return !Str.empty();
+ });
+
+ m_Router.RegisterRoute(
+ "hello",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "echo",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ IoBuffer Body = HttpReq.ReadPayload();
+ HttpContentType CT = HttpReq.RequestContentType();
+ HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body);
+ },
+ HttpVerb::kPost | HttpVerb::kPut);
+
+ m_Router.RegisterRoute(
+ "echo/headers",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view Auth = HttpReq.GetAuthorizationHeader();
+ CbObjectWriter Writer;
+ if (!Auth.empty())
+ {
+ Writer.AddString("Authorization", Auth);
+ }
+ HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save());
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "echo/method",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view Method = ToString(HttpReq.RequestVerb());
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method);
+ },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead);
+
+ m_Router.RegisterRoute(
+ "json",
+ [](HttpRouterRequest& Req) {
+ CbObjectWriter Obj;
+ Obj.AddBool("ok", true);
+ Obj.AddString("message", "test");
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "nocontent",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete);
+
+ m_Router.RegisterRoute(
+ "created",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::Created, HttpContentType::kText, "resource created");
+ },
+ HttpVerb::kPost | HttpVerb::kPut);
+
+ m_Router.RegisterRoute(
+ "content-type/text",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "plain text"); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "content-type/json",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"key\":\"value\"}");
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "content-type/binary",
+ [](HttpRouterRequest& Req) {
+ uint8_t Data[] = {0xDE, 0xAD, 0xBE, 0xEF};
+ IoBuffer Buf(IoBuffer::Clone, Data, sizeof(Data));
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "content-type/cbobject",
+ [](HttpRouterRequest& Req) {
+ CbObjectWriter Obj;
+ Obj.AddString("type", "cbobject");
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "auth/bearer",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view Auth = HttpReq.GetAuthorizationHeader();
+ if (Auth.starts_with("Bearer ") && Auth.size() > 7)
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "authenticated");
+ }
+ else
+ {
+ HttpReq.WriteResponse(HttpResponseCode::Unauthorized, HttpContentType::kText, "unauthorized");
+ }
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "slow",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) {
+ Sleep(2000);
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response");
+ });
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "large",
+ [](HttpRouterRequest& Req) {
+ constexpr size_t Size = 64 * 1024;
+ IoBuffer Buf(Size);
+ uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData());
+ for (size_t i = 0; i < Size; ++i)
+ {
+ Ptr[i] = static_cast<uint8_t>(i & 0xFF);
+ }
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "status/{statuscode}",
+ [](HttpRouterRequest& Req) {
+ std::string_view CodeStr = Req.GetCapture(1);
+ int Code = std::stoi(std::string{CodeStr});
+ const HttpResponseCode ResponseCode = static_cast<HttpResponseCode>(Code);
+ Req.ServerRequest().WriteResponse(ResponseCode);
+ },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead);
+
+ m_Router.RegisterRoute(
+ "attempt-counter",
+ [this](HttpRouterRequest& Req) {
+ uint32_t Count = m_AttemptCounter.fetch_add(1);
+ if (Count < m_FailCount)
+ {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::ServiceUnavailable);
+ }
+ else
+ {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "success after retries");
+ }
+ },
+ HttpVerb::kGet);
+ }
+
+ virtual const char* BaseUri() const override { return "/api/test/"; }
+ virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); }
+
+ void ResetAttemptCounter(uint32_t FailCount)
+ {
+ m_AttemptCounter.store(0);
+ m_FailCount = FailCount;
+ }
+
+private:
+ HttpRequestRouter m_Router;
+ std::atomic<uint32_t> m_AttemptCounter{0};
+ uint32_t m_FailCount = 2;
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Test server fixture
+
+struct TestServerFixture
+{
+ HttpClientTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+ Ref<HttpServer> Server;
+ std::thread ServerThread;
+ int Port = -1;
+
+ TestServerFixture()
+ {
+ Server = CreateHttpAsioServer(AsioConfig{});
+ Port = Server->Initialize(7600, TmpDir.Path());
+ ZEN_ASSERT(Port != -1);
+ Server->RegisterService(TestService);
+ ServerThread = std::thread([this]() { Server->Run(false); });
+ }
+
+ ~TestServerFixture()
+ {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ }
+
+ HttpClient MakeClient(HttpClientSettings Settings = {})
+ {
+ return HttpClient(fmt::format("127.0.0.1:{}", Port), Settings, /*CheckIfAbortFunction*/ {});
+ }
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Tests
+
+TEST_SUITE_BEGIN("http.httpclient");
+
+TEST_CASE("httpclient.verbs")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("GET returns 200 with expected body")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "GET");
+ }
+
+ SUBCASE("POST dispatches correctly")
+ {
+ HttpClient::Response Resp = Client.Post("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "POST");
+ }
+
+ SUBCASE("PUT dispatches correctly")
+ {
+ HttpClient::Response Resp = Client.Put("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "PUT");
+ }
+
+ SUBCASE("DELETE dispatches correctly")
+ {
+ HttpClient::Response Resp = Client.Delete("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "DELETE");
+ }
+
+ SUBCASE("HEAD returns 200 with empty body")
+ {
+ HttpClient::Response Resp = Client.Head("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), ""sv);
+ }
+}
+
+TEST_CASE("httpclient.get")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("simple GET with text response")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK);
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("GET with auth header via echo")
+ {
+ HttpClient::Response Resp =
+ Client.Get("/api/test/echo/headers", std::pair<std::string, std::string>("Authorization", "Bearer test-token-123"));
+ CHECK(Resp.IsSuccess());
+ CbObject Obj = Resp.AsObject();
+ CHECK_EQ(Obj["Authorization"].AsString(), "Bearer test-token-123");
+ }
+
+ SUBCASE("GET returning CbObject")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ CHECK(Resp.IsSuccess());
+ CbObject Obj = Resp.AsObject();
+ CHECK(Obj["ok"].AsBool() == true);
+ CHECK_EQ(Obj["message"].AsString(), "test");
+ }
+
+ SUBCASE("GET large payload")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/large");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+
+ const uint8_t* Data = static_cast<const uint8_t*>(Resp.ResponsePayload.GetData());
+ bool Valid = true;
+ for (size_t i = 0; i < 64 * 1024; ++i)
+ {
+ if (Data[i] != static_cast<uint8_t>(i & 0xFF))
+ {
+ Valid = false;
+ break;
+ }
+ }
+ CHECK(Valid);
+ }
+}
+
+TEST_CASE("httpclient.post")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("POST with IoBuffer payload echo round-trip")
+ {
+ const char* Payload = "test payload data";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "test payload data");
+ }
+
+ SUBCASE("POST with IoBuffer and explicit content type")
+ {
+ const char* Payload = "{\"key\":\"value\"}";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Buf, ZenContentType::kJSON);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "{\"key\":\"value\"}");
+ }
+
+ SUBCASE("POST with CbObject payload round-trip")
+ {
+ CbObjectWriter Writer;
+ Writer.AddBool("enabled", true);
+ Writer.AddString("name", "testobj");
+ CbObject Obj = Writer.Save();
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Obj);
+ CHECK(Resp.IsSuccess());
+ CbObject RoundTripped = Resp.AsObject();
+ CHECK(RoundTripped["enabled"].AsBool() == true);
+ CHECK_EQ(RoundTripped["name"].AsString(), "testobj");
+ }
+
+ SUBCASE("POST with CompositeBuffer payload")
+ {
+ const char* Part1 = "hello ";
+ const char* Part2 = "composite";
+ IoBuffer Buf1(IoBuffer::Clone, Part1, strlen(Part1));
+ IoBuffer Buf2(IoBuffer::Clone, Part2, strlen(Part2));
+
+ SharedBuffer Seg1{Buf1};
+ SharedBuffer Seg2{Buf2};
+ CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)};
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Composite, ZenContentType::kText);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello composite");
+ }
+
+ SUBCASE("POST with custom headers")
+ {
+ HttpClient::Response Resp = Client.Post("/api/test/echo/headers", HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{});
+ CHECK(Resp.IsSuccess());
+ }
+
+ SUBCASE("POST with empty body to nocontent endpoint")
+ {
+ HttpClient::Response Resp = Client.Post("/api/test/nocontent");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent);
+ }
+}
+
+TEST_CASE("httpclient.put")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("PUT with IoBuffer payload echo round-trip")
+ {
+ const char* Payload = "put payload data";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Put("/api/test/echo", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "put payload data");
+ }
+
+ SUBCASE("PUT with parameters only")
+ {
+ HttpClient::Response Resp = Client.Put("/api/test/nocontent");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent);
+ }
+
+ SUBCASE("PUT to created endpoint")
+ {
+ const char* Payload = "new resource";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Put("/api/test/created", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::Created);
+ CHECK_EQ(Resp.AsText(), "resource created");
+ }
+}
+
+TEST_CASE("httpclient.upload")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("Upload IoBuffer")
+ {
+ constexpr size_t Size = 128 * 1024;
+ IoBuffer Blob = CreateSemiRandomBlob(Size);
+
+ HttpClient::Response Resp = Client.Upload("/api/test/echo", Blob);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), Size);
+ }
+
+ SUBCASE("Upload CompositeBuffer")
+ {
+ IoBuffer Buf1 = CreateSemiRandomBlob(32 * 1024);
+ IoBuffer Buf2 = CreateSemiRandomBlob(32 * 1024);
+
+ SharedBuffer Seg1{Buf1};
+ SharedBuffer Seg2{Buf2};
+ CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)};
+
+ HttpClient::Response Resp = Client.Upload("/api/test/echo", Composite, ZenContentType::kBinary);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+ }
+}
+
+TEST_CASE("httpclient.download")
+{
+ TestServerFixture Fixture;
+ ScopedTemporaryDirectory DownloadDir;
+
+ SUBCASE("Download small payload stays in memory")
+ {
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response Resp = Client.Download("/api/test/hello", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("Download with reduced MaximumInMemoryDownloadSize forces file spill")
+ {
+ HttpClientSettings Settings;
+ Settings.MaximumInMemoryDownloadSize = 4;
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Download("/api/test/large", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+ }
+}
+
+TEST_CASE("httpclient.status-codes")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("2xx are success")
+ {
+ CHECK(Client.Get("/api/test/status/200").IsSuccess());
+ CHECK(Client.Get("/api/test/status/201").IsSuccess());
+ CHECK(Client.Get("/api/test/status/204").IsSuccess());
+ }
+
+ SUBCASE("4xx are not success")
+ {
+ CHECK(!Client.Get("/api/test/status/400").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/401").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/403").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/404").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/409").IsSuccess());
+ }
+
+ SUBCASE("5xx are not success")
+ {
+ CHECK(!Client.Get("/api/test/status/500").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/502").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/503").IsSuccess());
+ }
+
+ SUBCASE("status code values match")
+ {
+ CHECK_EQ(Client.Get("/api/test/status/200").StatusCode, HttpResponseCode::OK);
+ CHECK_EQ(Client.Get("/api/test/status/201").StatusCode, HttpResponseCode::Created);
+ CHECK_EQ(Client.Get("/api/test/status/204").StatusCode, HttpResponseCode::NoContent);
+ CHECK_EQ(Client.Get("/api/test/status/400").StatusCode, HttpResponseCode::BadRequest);
+ CHECK_EQ(Client.Get("/api/test/status/401").StatusCode, HttpResponseCode::Unauthorized);
+ CHECK_EQ(Client.Get("/api/test/status/403").StatusCode, HttpResponseCode::Forbidden);
+ CHECK_EQ(Client.Get("/api/test/status/404").StatusCode, HttpResponseCode::NotFound);
+ CHECK_EQ(Client.Get("/api/test/status/409").StatusCode, HttpResponseCode::Conflict);
+ CHECK_EQ(Client.Get("/api/test/status/500").StatusCode, HttpResponseCode::InternalServerError);
+ CHECK_EQ(Client.Get("/api/test/status/502").StatusCode, HttpResponseCode::BadGateway);
+ CHECK_EQ(Client.Get("/api/test/status/503").StatusCode, HttpResponseCode::ServiceUnavailable);
+ }
+}
+
+TEST_CASE("httpclient.response")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("IsSuccess and operator bool for success")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK(static_cast<bool>(Resp));
+ }
+
+ SUBCASE("IsSuccess and operator bool for failure")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/404");
+ CHECK(!Resp.IsSuccess());
+ CHECK(!static_cast<bool>(Resp));
+ }
+
+ SUBCASE("AsText returns body")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("AsText returns empty for no-content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/nocontent");
+ CHECK(Resp.AsText().empty());
+ }
+
+ SUBCASE("AsObject parses CbObject")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ CbObject Obj = Resp.AsObject();
+ CHECK(Obj["ok"].AsBool() == true);
+ CHECK_EQ(Obj["message"].AsString(), "test");
+ }
+
+ SUBCASE("AsObject returns empty for non-CB content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CbObject Obj = Resp.AsObject();
+ CHECK(!Obj);
+ }
+
+ SUBCASE("ToText for text content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/text");
+ CHECK_EQ(Resp.ToText(), "plain text");
+ }
+
+ SUBCASE("ToText for CbObject content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ std::string Text = Resp.ToText();
+ CHECK(!Text.empty());
+ // ToText for CbObject converts to JSON string representation
+ CHECK(Text.find("ok") != std::string::npos);
+ CHECK(Text.find("test") != std::string::npos);
+ }
+
+ SUBCASE("ErrorMessage includes status code on failure")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/404");
+ std::string Msg = Resp.ErrorMessage("test-prefix");
+ CHECK(Msg.find("test-prefix") != std::string::npos);
+ CHECK(Msg.find("404") != std::string::npos);
+ }
+
+ SUBCASE("ThrowError throws on failure")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/500");
+ CHECK_THROWS_AS(Resp.ThrowError("test"), HttpClientError);
+ }
+
+ SUBCASE("ThrowError does not throw on success")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK_NOTHROW(Resp.ThrowError("test"));
+ }
+
+ SUBCASE("HttpClientError carries response code")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/403");
+ try
+ {
+ Resp.ThrowError("test");
+ CHECK(false); // should not reach
+ }
+ catch (const HttpClientError& Err)
+ {
+ CHECK_EQ(Err.GetHttpResponseCode(), HttpResponseCode::Forbidden);
+ }
+ }
+}
+
+TEST_CASE("httpclient.error-handling")
+{
+ SUBCASE("Connection refused")
+ {
+ HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {});
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("Request timeout")
+ {
+ TestServerFixture Fixture;
+ HttpClientSettings Settings;
+ Settings.Timeout = std::chrono::milliseconds(500);
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/api/test/slow");
+ CHECK(!Resp.IsSuccess());
+ }
+
+ SUBCASE("Nonexistent endpoint returns failure")
+ {
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response Resp = Client.Get("/api/test/does-not-exist");
+ CHECK(!Resp.IsSuccess());
+ }
+}
+
+TEST_CASE("httpclient.session")
+{
+ TestServerFixture Fixture;
+
+ SUBCASE("Default session ID is non-empty")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ CHECK(!Client.GetSessionId().empty());
+ }
+
+ SUBCASE("SetSessionId changes ID")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ Oid NewId = Oid::NewOid();
+ std::string OldId = std::string(Client.GetSessionId());
+ Client.SetSessionId(NewId);
+ CHECK_EQ(Client.GetSessionId(), NewId.ToString());
+ CHECK_NE(Client.GetSessionId(), OldId);
+ }
+
+ SUBCASE("SetSessionId with Zero resets")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ Oid NewId = Oid::NewOid();
+ Client.SetSessionId(NewId);
+ CHECK_EQ(Client.GetSessionId(), NewId.ToString());
+ Client.SetSessionId(Oid::Zero);
+ // After resetting, should get a session string (not empty, not the custom one)
+ CHECK(!Client.GetSessionId().empty());
+ CHECK_NE(Client.GetSessionId(), NewId.ToString());
+ }
+}
+
+TEST_CASE("httpclient.authentication")
+{
+ TestServerFixture Fixture;
+
+ SUBCASE("Authenticate returns false without provider")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ CHECK(!Client.Authenticate());
+ }
+
+ SUBCASE("Authenticate returns true with valid token")
+ {
+ HttpClientSettings Settings;
+ Settings.AccessTokenProvider = []() -> HttpClientAccessToken {
+ return HttpClientAccessToken{
+ .Value = "valid-token",
+ .ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours(1),
+ };
+ };
+ HttpClient Client = Fixture.MakeClient(Settings);
+ CHECK(Client.Authenticate());
+ }
+
+ SUBCASE("Authenticate returns false with expired token")
+ {
+ HttpClientSettings Settings;
+ Settings.AccessTokenProvider = []() -> HttpClientAccessToken {
+ return HttpClientAccessToken{
+ .Value = "expired-token",
+ .ExpireTime = HttpClientAccessToken::Clock::now() - std::chrono::hours(1),
+ };
+ };
+ HttpClient Client = Fixture.MakeClient(Settings);
+ CHECK(!Client.Authenticate());
+ }
+
+ SUBCASE("Bearer token verified by auth endpoint")
+ {
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response AuthResp =
+ Client.Get("/api/test/auth/bearer", std::pair<std::string, std::string>("Authorization", "Bearer my-secret-token"));
+ CHECK(AuthResp.IsSuccess());
+ CHECK_EQ(AuthResp.AsText(), "authenticated");
+ }
+
+ SUBCASE("Request without token to auth endpoint gets 401")
+ {
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response Resp = Client.Get("/api/test/auth/bearer");
+ CHECK(!Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::Unauthorized);
+ }
+}
+
+TEST_CASE("httpclient.content-types")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("text content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/text");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kText);
+ }
+
+ SUBCASE("JSON content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/json");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kJSON);
+ }
+
+ SUBCASE("binary content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/binary");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kBinary);
+ }
+
+ SUBCASE("CbObject content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/cbobject");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kCbObject);
+ }
+}
+
+TEST_CASE("httpclient.metadata")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("ElapsedSeconds is positive")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.ElapsedSeconds > 0.0);
+ }
+
+ SUBCASE("DownloadedBytes populated for GET")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.DownloadedBytes > 0);
+ }
+
+ SUBCASE("UploadedBytes populated for POST with payload")
+ {
+ const char* Payload = "some upload data";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.UploadedBytes > 0);
+ }
+}
+
+TEST_CASE("httpclient.retry")
+{
+ TestServerFixture Fixture;
+
+ SUBCASE("Retry succeeds after transient failures")
+ {
+ Fixture.TestService.ResetAttemptCounter(2);
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 3;
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/api/test/attempt-counter");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "success after retries");
+ }
+
+ SUBCASE("No retry returns 503 immediately")
+ {
+ Fixture.TestService.ResetAttemptCounter(2);
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 0;
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/api/test/attempt-counter");
+ CHECK(!Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::ServiceUnavailable);
+ }
+}
+
+TEST_CASE("httpclient.measurelatency")
+{
+ SUBCASE("Successful measurement against live server")
+ {
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello");
+ CHECK(Result.Success);
+ CHECK(Result.LatencySeconds > 0.0);
+ }
+
+ SUBCASE("Failed measurement against unreachable port")
+ {
+ HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {});
+ LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello");
+ CHECK(!Result.Success);
+ CHECK(!Result.FailureReason.empty());
+ }
+}
+
+TEST_CASE("httpclient.keyvaluemap")
+{
+ SUBCASE("Default construction is empty")
+ {
+ HttpClient::KeyValueMap Map;
+ CHECK(Map->empty());
+ }
+
+ SUBCASE("Construction from pair")
+ {
+ HttpClient::KeyValueMap Map(std::pair<std::string, std::string>("key", "value"));
+ CHECK_EQ(Map->size(), 1u);
+ CHECK_EQ(Map->at("key"), "value");
+ }
+
+ SUBCASE("Construction from string_view pair")
+ {
+ HttpClient::KeyValueMap Map(std::pair<std::string_view, std::string_view>("key"sv, "value"sv));
+ CHECK_EQ(Map->size(), 1u);
+ CHECK_EQ(Map->at("key"), "value");
+ }
+
+ SUBCASE("Construction from initializer list")
+ {
+ HttpClient::KeyValueMap Map({{"a"sv, "1"sv}, {"b"sv, "2"sv}});
+ CHECK_EQ(Map->size(), 2u);
+ CHECK_EQ(Map->at("a"), "1");
+ CHECK_EQ(Map->at("b"), "2");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Transport fault testing
+
+static std::string
+MakeRawHttpResponse(int StatusCode, std::string_view Body)
+{
+ return fmt::format(
+ "HTTP/1.1 {} OK\r\n"
+ "Content-Type: text/plain\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n"
+ "{}",
+ StatusCode,
+ Body.size(),
+ Body);
+}
+
+static std::string
+MakeRawHttpHeaders(int StatusCode, size_t ContentLength)
+{
+ return fmt::format(
+ "HTTP/1.1 {} OK\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n",
+ StatusCode,
+ ContentLength);
+}
+
+static void
+DrainHttpRequest(asio::ip::tcp::socket& Socket)
+{
+ asio::streambuf Buf;
+ std::error_code Ec;
+ asio::read_until(Socket, Buf, "\r\n\r\n", Ec);
+}
+
+static void
+DrainFullHttpRequest(asio::ip::tcp::socket& Socket)
+{
+ // Read until end of headers
+ asio::streambuf Buf;
+ std::error_code Ec;
+ asio::read_until(Socket, Buf, "\r\n\r\n", Ec);
+ if (Ec)
+ {
+ return;
+ }
+
+ // Extract headers to find Content-Length
+ std::string Headers(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data()));
+
+ size_t ContentLength = 0;
+ auto Pos = Headers.find("Content-Length: ");
+ if (Pos == std::string::npos)
+ {
+ Pos = Headers.find("content-length: ");
+ }
+ if (Pos != std::string::npos)
+ {
+ size_t ValStart = Pos + 16; // length of "Content-Length: "
+ size_t ValEnd = Headers.find("\r\n", ValStart);
+ if (ValEnd != std::string::npos)
+ {
+ ContentLength = std::stoull(Headers.substr(ValStart, ValEnd - ValStart));
+ }
+ }
+
+ // Calculate how many body bytes were already read past the header boundary.
+ // asio::read_until may read past the delimiter, so Buf.data() contains everything read.
+ size_t HeaderEnd = Headers.find("\r\n\r\n") + 4;
+ size_t BodyBytesInBuf = Headers.size() > HeaderEnd ? Headers.size() - HeaderEnd : 0;
+ size_t Remaining = ContentLength > BodyBytesInBuf ? ContentLength - BodyBytesInBuf : 0;
+
+ if (Remaining > 0)
+ {
+ std::vector<char> BodyBuf(Remaining);
+ asio::read(Socket, asio::buffer(BodyBuf), Ec);
+ }
+}
+
+static void
+DrainPartialBody(asio::ip::tcp::socket& Socket, size_t BytesToRead)
+{
+ // Read headers first
+ asio::streambuf Buf;
+ std::error_code Ec;
+ asio::read_until(Socket, Buf, "\r\n\r\n", Ec);
+ if (Ec)
+ {
+ return;
+ }
+
+ // Determine how many body bytes were already buffered past headers
+ std::string All(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data()));
+ size_t HeaderEnd = All.find("\r\n\r\n") + 4;
+ size_t BodyBytesInBuf = All.size() > HeaderEnd ? All.size() - HeaderEnd : 0;
+
+ if (BodyBytesInBuf < BytesToRead)
+ {
+ size_t Remaining = BytesToRead - BodyBytesInBuf;
+ std::vector<char> BodyBuf(Remaining);
+ asio::read(Socket, asio::buffer(BodyBuf), Ec);
+ }
+}
+
+struct FaultTcpServer
+{
+ using FaultHandler = std::function<void(asio::ip::tcp::socket&)>;
+
+ asio::io_context m_IoContext;
+ asio::ip::tcp::acceptor m_Acceptor;
+ FaultHandler m_Handler;
+ std::thread m_Thread;
+ int m_Port;
+
+ explicit FaultTcpServer(FaultHandler Handler)
+ : m_Acceptor(m_IoContext, asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), 0))
+ , m_Handler(std::move(Handler))
+ {
+ m_Port = m_Acceptor.local_endpoint().port();
+ StartAccept();
+ m_Thread = std::thread([this]() { m_IoContext.run(); });
+ }
+
+ ~FaultTcpServer()
+ {
+ std::error_code Ec;
+ m_Acceptor.close(Ec);
+ m_IoContext.stop();
+ if (m_Thread.joinable())
+ {
+ m_Thread.join();
+ }
+ }
+
+ FaultTcpServer(const FaultTcpServer&) = delete;
+ FaultTcpServer& operator=(const FaultTcpServer&) = delete;
+
+ void StartAccept()
+ {
+ m_Acceptor.async_accept([this](std::error_code Ec, asio::ip::tcp::socket Socket) {
+ if (!Ec)
+ {
+ m_Handler(Socket);
+ }
+ if (m_Acceptor.is_open())
+ {
+ StartAccept();
+ }
+ });
+ }
+
+ HttpClient MakeClient(HttpClientSettings Settings = {})
+ {
+ return HttpClient(fmt::format("127.0.0.1:{}", m_Port), Settings, /*CheckIfAbortFunction*/ {});
+ }
+};
+
+TEST_CASE("httpclient.transport-faults" * doctest::skip())
+{
+ SUBCASE("connection reset before response")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("connection closed before response")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("partial headers then close")
+ {
+ // libcurl parses the status line (200 OK) and accepts the response even though
+ // headers are truncated mid-field. It reports success with an empty body instead
+ // of an error. Ideally this should be detected as a transport failure.
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Partial = "HTTP/1.1 200 OK\r\nContent-";
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Partial), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ WARN(!Resp.IsSuccess());
+ WARN(Resp.Error.has_value());
+ }
+
+ SUBCASE("truncated body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Headers = MakeRawHttpHeaders(200, 1000);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Headers), Ec);
+ std::string PartialBody(100, 'x');
+ asio::write(Socket, asio::buffer(PartialBody), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("connection reset mid-body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Headers = MakeRawHttpHeaders(200, 10000);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Headers), Ec);
+ std::string PartialBody(1000, 'x');
+ asio::write(Socket, asio::buffer(PartialBody), Ec);
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("stalled response triggers timeout")
+ {
+ std::atomic<bool> StallActive{true};
+ FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Headers = MakeRawHttpHeaders(200, 1000);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Headers), Ec);
+ while (StallActive.load())
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(50));
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.Timeout = std::chrono::milliseconds(500);
+ HttpClient Client = Server.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ StallActive.store(false);
+ }
+
+ SUBCASE("retry succeeds after transient failures")
+ {
+ std::atomic<int> ConnCount{0};
+ FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) {
+ int N = ConnCount.fetch_add(1);
+ DrainHttpRequest(Socket);
+ if (N < 2)
+ {
+ // Connection reset produces NETWORK_SEND_FAILURE which is retryable
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ }
+ else
+ {
+ std::string Response = MakeRawHttpResponse(200, "recovered");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 3;
+ HttpClient Client = Server.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "recovered");
+ }
+}
+
+TEST_CASE("httpclient.transport-faults-post" * doctest::skip())
+{
+ constexpr size_t kPostBodySize = 256 * 1024;
+
+ auto MakePostBody = []() -> IoBuffer {
+ IoBuffer Buf(kPostBodySize);
+ uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData());
+ for (size_t i = 0; i < kPostBodySize; ++i)
+ {
+ Ptr[i] = static_cast<uint8_t>(i & 0xFF);
+ }
+ Buf.SetContentType(ZenContentType::kBinary);
+ return Buf;
+ };
+
+ SUBCASE("POST: server resets before consuming body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("POST: server closes before consuming body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("POST: server resets mid-body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainPartialBody(Socket, 8 * 1024);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("POST: early error response before consuming body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = MakeRawHttpResponse(503, "service busy");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ // With a large upload body, the server may RST the connection before the client
+ // reads the 503 response. Either outcome is valid: the client sees the HTTP 503
+ // status, or it sees a transport-level error from the RST.
+ CHECK((Resp.StatusCode == HttpResponseCode::ServiceUnavailable || Resp.Error.has_value()));
+ }
+
+ SUBCASE("POST: stalled upload triggers timeout")
+ {
+ std::atomic<bool> StallActive{true};
+ FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ // Stop reading body — TCP window will fill and client send will stall
+ while (StallActive.load())
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(50));
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.Timeout = std::chrono::milliseconds(2000);
+ HttpClient Client = Server.MakeClient(Settings);
+
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ StallActive.store(false);
+ }
+
+ SUBCASE("POST: retry with large body after transient failure")
+ {
+ std::atomic<int> ConnCount{0};
+ FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) {
+ int N = ConnCount.fetch_add(1);
+ if (N < 2)
+ {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ }
+ else
+ {
+ DrainFullHttpRequest(Socket);
+ std::string Response = MakeRawHttpResponse(200, "upload-ok");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 3;
+ HttpClient Client = Server.MakeClient(Settings);
+
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "upload-ok");
+ }
+}
+
+TEST_SUITE_END();
+
+void
+httpclient_test_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif
diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp
index 72df12d02..02e1b57e2 100644
--- a/src/zenhttp/httpclientauth.cpp
+++ b/src/zenhttp/httpclientauth.cpp
@@ -170,7 +170,7 @@ namespace zen { namespace httpclientauth {
time_t UTCTime = timegm(&Time);
HttpClientAccessToken::TimePoint ExpireTime = std::chrono::system_clock::from_time_t(UTCTime);
- ExpireTime += std::chrono::microseconds(Millisecond);
+ ExpireTime += std::chrono::milliseconds(Millisecond);
return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime};
}
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index c4e67d4ed..9bae95690 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -23,10 +23,12 @@
#include <zencore/logging.h>
#include <zencore/stream.h>
#include <zencore/string.h>
+#include <zencore/system.h>
#include <zencore/testing.h>
#include <zencore/thread.h>
#include <zenhttp/packageformat.h>
#include <zentelemetry/otlptrace.h>
+#include <zentelemetry/stats.h>
#include <charconv>
#include <mutex>
@@ -463,7 +465,7 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
//////////////////////////////////////////////////////////////////////////
-HttpServerRequest::HttpServerRequest(HttpService& Service) : m_BaseUri(Service.BaseUri())
+HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service)
{
}
@@ -745,6 +747,10 @@ HttpRequestRouter::RegisterRoute(const char* UriPattern, HttpRequestRouter::Hand
{
if (UriPattern[i] == '}')
{
+ if (i == PatternStart)
+ {
+ throw std::runtime_error(fmt::format("matcher pattern is empty in URI pattern '{}'", UriPattern));
+ }
std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart);
if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end())
{
@@ -910,8 +916,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
CapturedSegments.emplace_back(Uri);
- for (int MatcherIndex : Matchers)
+ for (size_t MatcherOffset = 0; MatcherOffset < Matchers.size(); MatcherOffset++)
{
+ int MatcherIndex = Matchers[MatcherOffset];
if (UriPos >= UriLen)
{
IsMatch = false;
@@ -921,9 +928,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
if (MatcherIndex < 0)
{
// Literal match
- int LitIndex = -MatcherIndex - 1;
- const std::string& LitStr = m_Literals[LitIndex];
- size_t LitLen = LitStr.length();
+ int LitIndex = -MatcherIndex - 1;
+ std::string_view LitStr = m_Literals[LitIndex];
+ size_t LitLen = LitStr.length();
if (Uri.substr(UriPos, LitLen) == LitStr)
{
@@ -939,9 +946,18 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
{
// Matcher function
size_t SegmentStart = UriPos;
- while (UriPos < UriLen && Uri[UriPos] != '/')
+
+ if (MatcherOffset == (Matchers.size() - 1))
+ {
+ // Last matcher, use the remaining part of the uri
+ UriPos = UriLen;
+ }
+ else
{
- ++UriPos;
+ while (UriPos < UriLen && Uri[UriPos] != '/')
+ {
+ ++UriPos;
+ }
}
std::string_view Segment = Uri.substr(SegmentStart, UriPos - SegmentStart);
@@ -970,7 +986,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan())
{
ExtendableStringBuilder<128> RoutePath;
- RoutePath.Append(Request.BaseUri());
+ RoutePath.Append(Request.Service().BaseUri());
RoutePath.Append(Handler.Pattern);
ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView());
}
@@ -994,7 +1010,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan())
{
ExtendableStringBuilder<128> RoutePath;
- RoutePath.Append(Request.BaseUri());
+ RoutePath.Append(Request.Service().BaseUri());
RoutePath.Append(Handler.Pattern);
ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView());
}
@@ -1014,7 +1030,28 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
int
HttpServer::Initialize(int BasePort, std::filesystem::path DataDir)
{
- return OnInitialize(BasePort, std::move(DataDir));
+ m_EffectivePort = OnInitialize(BasePort, std::move(DataDir));
+ m_ExternalHost = OnGetExternalHost();
+ return m_EffectivePort;
+}
+
+std::string
+HttpServer::OnGetExternalHost() const
+{
+ return GetMachineName();
+}
+
+std::string
+HttpServer::GetServiceUri(const HttpService* Service) const
+{
+ if (Service)
+ {
+ return fmt::format("http://{}:{}{}", m_ExternalHost, m_EffectivePort, Service->BaseUri());
+ }
+ else
+ {
+ return fmt::format("http://{}:{}", m_ExternalHost, m_EffectivePort);
+ }
}
void
@@ -1052,6 +1089,45 @@ HttpServer::EnumerateServices(std::function<void(HttpService& Service)>&& Callba
}
}
+void
+HttpServer::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ OnSetHttpRequestFilter(RequestFilter);
+}
+
+CbObject
+HttpServer::CollectStats()
+{
+ CbObjectWriter Cbo;
+
+ metrics::EmitSnapshot("requests", m_RequestMeter, Cbo);
+
+ Cbo.BeginObject("bytes");
+ {
+ Cbo << "received" << GetTotalBytesReceived();
+ Cbo << "sent" << GetTotalBytesSent();
+ }
+ Cbo.EndObject();
+
+ Cbo.BeginObject("websockets");
+ {
+ Cbo << "active_connections" << GetActiveWebSocketConnectionCount();
+ Cbo << "frames_received" << m_WsFramesReceived.load(std::memory_order_relaxed);
+ Cbo << "frames_sent" << m_WsFramesSent.load(std::memory_order_relaxed);
+ Cbo << "bytes_received" << m_WsBytesReceived.load(std::memory_order_relaxed);
+ Cbo << "bytes_sent" << m_WsBytesSent.load(std::memory_order_relaxed);
+ }
+ Cbo.EndObject();
+
+ return Cbo.Save();
+}
+
+void
+HttpServer::HandleStatsRequest(HttpServerRequest& Request)
+{
+ Request.WriteResponse(HttpResponseCode::OK, CollectStats());
+}
+
//////////////////////////////////////////////////////////////////////////
HttpRpcHandler::HttpRpcHandler()
@@ -1294,6 +1370,8 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("http.httpserver");
+
TEST_CASE("http.common")
{
using namespace std::literals;
@@ -1310,7 +1388,11 @@ TEST_CASE("http.common")
{
TestHttpServerRequest(HttpService& Service, std::string_view Uri) : HttpServerRequest(Service) { m_Uri = Uri; }
virtual IoBuffer ReadPayload() override { return IoBuffer(); }
- virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override
+
+ virtual bool IsLocalMachineRequest() const override { return false; }
+ virtual std::string_view GetAuthorizationHeader() const override { return {}; }
+
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override
{
ZEN_UNUSED(ResponseCode, ContentType, Blobs);
}
@@ -1395,20 +1477,33 @@ TEST_CASE("http.common")
SUBCASE("router-matcher")
{
- bool HandledA = false;
- bool HandledAA = false;
- bool HandledAB = false;
- bool HandledAandB = false;
+ bool HandledA = false;
+ bool HandledAA = false;
+ bool HandledAB = false;
+ bool HandledAandB = false;
+ bool HandledAandPath = false;
std::vector<std::string> Captures;
auto Reset = [&] {
- HandledA = HandledAA = HandledAB = HandledAandB = false;
+ HandledA = HandledAA = HandledAB = HandledAandB = HandledAandPath = false;
Captures.clear();
};
TestHttpService Service;
HttpRequestRouter r;
- r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0; });
- r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0; });
+
+ r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0 && In.find('/') == std::string_view::npos; });
+ r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0 && In.find('/') == std::string_view::npos; });
+ static constexpr AsciiSet ValidPathCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789/_.,;$~{}+-[]%()]ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+ r.AddMatcher("path", [](std::string_view Str) -> bool { return !Str.empty() && AsciiSet::HasOnly(Str, ValidPathCharactersSet); });
+
+ r.RegisterRoute(
+ "path/{a}/{path}",
+ [&](auto& Req) {
+ HandledAandPath = true;
+ Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))};
+ },
+ HttpVerb::kGet);
+
r.RegisterRoute(
"{a}",
[&](auto& Req) {
@@ -1437,7 +1532,6 @@ TEST_CASE("http.common")
Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))};
},
HttpVerb::kGet);
-
{
Reset();
TestHttpServerRequest req{Service, "ab"sv};
@@ -1445,6 +1539,7 @@ TEST_CASE("http.common")
CHECK(HandledA);
CHECK(!HandledAA);
CHECK(!HandledAB);
+ CHECK(!HandledAandPath);
REQUIRE_EQ(Captures.size(), 1);
CHECK_EQ(Captures[0], "ab"sv);
@@ -1457,6 +1552,7 @@ TEST_CASE("http.common")
CHECK(!HandledA);
CHECK(!HandledAA);
CHECK(HandledAB);
+ CHECK(!HandledAandPath);
REQUIRE_EQ(Captures.size(), 2);
CHECK_EQ(Captures[0], "ab"sv);
CHECK_EQ(Captures[1], "def"sv);
@@ -1470,6 +1566,7 @@ TEST_CASE("http.common")
CHECK(!HandledAA);
CHECK(!HandledAB);
CHECK(HandledAandB);
+ CHECK(!HandledAandPath);
REQUIRE_EQ(Captures.size(), 2);
CHECK_EQ(Captures[0], "ab"sv);
CHECK_EQ(Captures[1], "def"sv);
@@ -1482,6 +1579,7 @@ TEST_CASE("http.common")
CHECK(!HandledA);
CHECK(!HandledAA);
CHECK(!HandledAB);
+ CHECK(!HandledAandPath);
}
{
@@ -1491,6 +1589,35 @@ TEST_CASE("http.common")
CHECK(HandledA);
CHECK(!HandledAA);
CHECK(!HandledAB);
+ CHECK(!HandledAandPath);
+ REQUIRE_EQ(Captures.size(), 1);
+ CHECK_EQ(Captures[0], "a123"sv);
+ }
+
+ {
+ Reset();
+ TestHttpServerRequest req{Service, "path/ab/simple_path.txt"sv};
+ r.HandleRequest(req);
+ CHECK(!HandledA);
+ CHECK(!HandledAA);
+ CHECK(!HandledAB);
+ CHECK(HandledAandPath);
+ REQUIRE_EQ(Captures.size(), 2);
+ CHECK_EQ(Captures[0], "ab"sv);
+ CHECK_EQ(Captures[1], "simple_path.txt"sv);
+ }
+
+ {
+ Reset();
+ TestHttpServerRequest req{Service, "path/ab/directory/and/path.txt"sv};
+ r.HandleRequest(req);
+ CHECK(!HandledA);
+ CHECK(!HandledAA);
+ CHECK(!HandledAB);
+ CHECK(HandledAandPath);
+ REQUIRE_EQ(Captures.size(), 2);
+ CHECK_EQ(Captures[0], "ab"sv);
+ CHECK_EQ(Captures[1], "directory/and/path.txt"sv);
}
}
@@ -1508,6 +1635,8 @@ TEST_CASE("http.common")
}
}
+TEST_SUITE_END();
+
void
http_forcelink()
{
diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h
index a988346e0..c252a5d99 100644
--- a/src/zenhttp/include/zenhttp/cprutils.h
+++ b/src/zenhttp/include/zenhttp/cprutils.h
@@ -66,10 +66,10 @@ struct fmt::formatter<cpr::Response>
Response.url.str(),
Response.status_code,
zen::ToString(zen::HttpResponseCode(Response.status_code)),
+ Response.reason,
Response.uploaded_bytes,
Response.downloaded_bytes,
NiceResponseTime.c_str(),
- Response.reason,
Json);
}
else
@@ -82,10 +82,10 @@ struct fmt::formatter<cpr::Response>
Response.url.str(),
Response.status_code,
zen::ToString(zen::HttpResponseCode(Response.status_code)),
+ Response.reason,
Response.uploaded_bytes,
Response.downloaded_bytes,
NiceResponseTime.c_str(),
- Response.reason,
Body.GetText());
}
}
diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h
index addb00cb8..57ab01158 100644
--- a/src/zenhttp/include/zenhttp/formatters.h
+++ b/src/zenhttp/include/zenhttp/formatters.h
@@ -73,7 +73,7 @@ struct fmt::formatter<zen::HttpClient::Response>
if (Response.IsSuccess())
{
return fmt::format_to(Ctx.out(),
- "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s",
+ "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}",
ToString(Response.StatusCode),
Response.UploadedBytes,
Response.DownloadedBytes,
diff --git a/src/zenhttp/include/zenhttp/httpapiservice.h b/src/zenhttp/include/zenhttp/httpapiservice.h
index 0270973bf..2d384d1d8 100644
--- a/src/zenhttp/include/zenhttp/httpapiservice.h
+++ b/src/zenhttp/include/zenhttp/httpapiservice.h
@@ -1,4 +1,5 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#pragma once
#include <zenhttp/httpserver.h>
diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h
index 9a9b74d72..1bb36a298 100644
--- a/src/zenhttp/include/zenhttp/httpclient.h
+++ b/src/zenhttp/include/zenhttp/httpclient.h
@@ -13,6 +13,7 @@
#include <functional>
#include <optional>
#include <unordered_map>
+#include <vector>
namespace zen {
@@ -58,6 +59,10 @@ struct HttpClientSettings
Oid SessionId = Oid::Zero;
bool Verbose = false;
uint64_t MaximumInMemoryDownloadSize = 1024u * 1024u;
+
+ /// HTTP status codes that are expected and should not be logged as warnings.
+ /// 404 is always treated as expected regardless of this list.
+ std::vector<HttpResponseCode> ExpectedErrorCodes;
};
class HttpClientError : public std::runtime_error
@@ -113,6 +118,15 @@ private:
class HttpClientBase;
+/** HTTP Client
+ *
+ * This is safe for use on multiple threads simultaneously, as each
+ * instance maintains an internal connection pool and will synchronize
+ * access to it as needed.
+ *
+ * Uses libcurl under the hood. We currently only use HTTP 1.1 features.
+ *
+ */
class HttpClient
{
public:
@@ -123,8 +137,11 @@ public:
struct ErrorContext
{
- int ErrorCode;
+ int ErrorCode = 0;
std::string ErrorMessage;
+
+ /** True when the error is a transport-level connection failure (connect timeout, refused, DNS) */
+ bool IsConnectionError() const;
};
struct KeyValueMap
@@ -171,13 +188,29 @@ public:
KeyValueMap Header;
// The number of bytes sent as part of the request
- int64_t UploadedBytes;
+ int64_t UploadedBytes = 0;
// The number of bytes received as part of the response
- int64_t DownloadedBytes;
+ int64_t DownloadedBytes = 0;
// The elapsed time in seconds for the request to execute
- double ElapsedSeconds;
+ double ElapsedSeconds = 0.0;
+
+ struct MultipartBoundary
+ {
+ uint64_t OffsetInPayload = 0;
+ uint64_t RangeOffset = 0;
+ uint64_t RangeLength = 0;
+ HttpContentType ContentType;
+ };
+
+ // Ranges will map out all received ranges, both single and multi-range responses
+ // If no range was requested Ranges will be empty
+ std::vector<MultipartBoundary> Ranges;
+
+ // Map the absolute OffsetAndLengthPairs into ResponsePayload from the ranges received (Ranges).
+ // If the response was not a partial response, an empty vector will be returned
+ std::vector<std::pair<uint64_t, uint64_t>> GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const;
// This contains any errors from the HTTP stack. It won't contain information on
// why the server responded with a non-success HTTP status, that may be gleaned
@@ -260,6 +293,16 @@ private:
const HttpClientSettings m_ConnectionSettings;
};
-void httpclient_forcelink(); // internal
+struct LatencyTestResult
+{
+ bool Success = false;
+ std::string FailureReason;
+ double LatencySeconds = -1.0;
+};
+
+LatencyTestResult MeasureLatency(HttpClient& Client, std::string_view Url);
+
+void httpclient_forcelink(); // internal
+void httpclient_test_forcelink(); // internal
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h
index bc18549c9..8fca35ac5 100644
--- a/src/zenhttp/include/zenhttp/httpcommon.h
+++ b/src/zenhttp/include/zenhttp/httpcommon.h
@@ -184,6 +184,13 @@ IsHttpSuccessCode(HttpResponseCode HttpCode) noexcept
return IsHttpSuccessCode(int(HttpCode));
}
+[[nodiscard]] inline bool
+IsHttpOk(HttpResponseCode HttpCode) noexcept
+{
+ return HttpCode == HttpResponseCode::OK || HttpCode == HttpResponseCode::Created || HttpCode == HttpResponseCode::Accepted ||
+ HttpCode == HttpResponseCode::NoContent;
+}
+
std::string_view ToString(HttpResponseCode HttpCode);
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index 3438a1471..0e1714669 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -13,6 +13,8 @@
#include <zencore/uid.h>
#include <zenhttp/httpcommon.h>
+#include <zentelemetry/stats.h>
+
#include <functional>
#include <gsl/gsl-lite.hpp>
#include <list>
@@ -30,16 +32,18 @@ class HttpService;
*/
class HttpServerRequest
{
-public:
+protected:
explicit HttpServerRequest(HttpService& Service);
+
+public:
~HttpServerRequest();
// Synchronous operations
[[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix
- [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; }
+ [[nodiscard]] inline std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; }
[[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; }
- [[nodiscard]] inline std::string_view BaseUri() const { return m_BaseUri; } // Service prefix
+ [[nodiscard]] inline HttpService& Service() const { return m_Service; }
struct QueryParams
{
@@ -79,6 +83,18 @@ public:
inline bool IsHandled() const { return !!(m_Flags & kIsHandled); }
inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); }
inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; }
+ inline void SetLogRequest(bool ShouldLog)
+ {
+ if (ShouldLog)
+ {
+ m_Flags |= kLogRequest;
+ }
+ else
+ {
+ m_Flags &= ~kLogRequest;
+ }
+ }
+ inline bool ShouldLogRequest() const { return !!(m_Flags & kLogRequest); }
/** Read POST/PUT payload for request body, which is always available without delay
*/
@@ -87,6 +103,10 @@ public:
CbObject ReadPayloadObject();
CbPackage ReadPayloadPackage();
+ virtual bool IsLocalMachineRequest() const = 0;
+ virtual std::string_view GetAuthorizationHeader() const = 0;
+ virtual std::string_view GetRemoteAddress() const { return {}; }
+
/** Respond with payload
No data will have been sent when any of these functions return. Instead, the response will be transmitted
@@ -115,15 +135,17 @@ protected:
kSuppressBody = 1 << 1,
kHaveRequestId = 1 << 2,
kHaveSessionId = 1 << 3,
+ kLogRequest = 1 << 4,
};
- mutable uint32_t m_Flags = 0;
+ mutable uint32_t m_Flags = 0;
+
+ HttpService& m_Service; // Service handling this request
HttpVerb m_Verb = HttpVerb::kGet;
HttpContentType m_ContentType = HttpContentType::kBinary;
HttpContentType m_AcceptType = HttpContentType::kUnknownContentType;
uint64_t m_ContentLength = ~0ull;
- std::string_view m_BaseUri; // Base URI path of the service handling this request
- std::string_view m_Uri; // URI without service prefix
+ std::string_view m_Uri; // URI without service prefix
std::string_view m_UriWithExtension;
std::string_view m_QueryString;
mutable uint32_t m_RequestId = ~uint32_t(0);
@@ -144,6 +166,19 @@ public:
virtual void OnRequestComplete() = 0;
};
+class IHttpRequestFilter
+{
+public:
+ virtual ~IHttpRequestFilter() {}
+ enum class Result
+ {
+ Forbidden,
+ ResponseSent,
+ Accepted
+ };
+ virtual Result FilterRequest(HttpServerRequest& Request) = 0;
+};
+
/**
* Base class for implementing an HTTP "service"
*
@@ -170,30 +205,110 @@ private:
int m_UriPrefixLength = 0;
};
+struct IHttpStatsProvider
+{
+ /** Handle an HTTP stats request, writing the response directly.
+ * Implementations may inspect query parameters on the request
+ * to include optional detailed breakdowns.
+ */
+ virtual void HandleStatsRequest(HttpServerRequest& Request) = 0;
+
+ /** Return the provider's current stats as a CbObject snapshot.
+ * Used by the WebSocket push thread to broadcast live updates
+ * without requiring an HttpServerRequest. Providers that do
+ * not override this will be skipped in WebSocket broadcasts.
+ */
+ virtual CbObject CollectStats() { return {}; }
+};
+
+struct IHttpStatsService
+{
+ virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
+ virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
+};
+
/** HTTP server
*
* Implements the main event loop to service HTTP requests, and handles routing
* requests to the appropriate handler as registered via RegisterService
*/
-class HttpServer : public RefCounted
+class HttpServer : public RefCounted, public IHttpStatsProvider
{
public:
void RegisterService(HttpService& Service);
void EnumerateServices(std::function<void(HttpService&)>&& Callback);
+ void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter);
int Initialize(int BasePort, std::filesystem::path DataDir);
void Run(bool IsInteractiveSession);
void RequestExit();
void Close();
+ /** Returns a canonical http:// URI for the given service, using the external
+ * IP and the port the server is actually listening on. Only valid
+ * after Initialize() has returned successfully.
+ */
+ std::string GetServiceUri(const HttpService* Service) const;
+
+ /** Returns the external host string (IP or hostname) determined during Initialize().
+ * Only valid after Initialize() has returned successfully.
+ */
+ std::string_view GetExternalHost() const { return m_ExternalHost; }
+
+ /** Returns total bytes received and sent across all connections since server start. */
+ virtual uint64_t GetTotalBytesReceived() const { return 0; }
+ virtual uint64_t GetTotalBytesSent() const { return 0; }
+
+ /** Mark that a request has been handled. Called by server implementations. */
+ void MarkRequest() { m_RequestMeter.Mark(); }
+
+ /** Set a default redirect path for root requests */
+ void SetDefaultRedirect(std::string_view Path) { m_DefaultRedirect = Path; }
+
+ std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; }
+
+ /** Track active WebSocket connections — called by server implementations on upgrade/close. */
+ void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); }
+ void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); }
+ uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); }
+
+ /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. */
+ void OnWebSocketFrameReceived(uint64_t Bytes)
+ {
+ m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed);
+ m_WsBytesReceived.fetch_add(Bytes, std::memory_order_relaxed);
+ }
+ void OnWebSocketFrameSent(uint64_t Bytes)
+ {
+ m_WsFramesSent.fetch_add(1, std::memory_order_relaxed);
+ m_WsBytesSent.fetch_add(Bytes, std::memory_order_relaxed);
+ }
+
+ // IHttpStatsProvider
+ virtual CbObject CollectStats() override;
+ virtual void HandleStatsRequest(HttpServerRequest& Request) override;
+
private:
std::vector<HttpService*> m_KnownServices;
+ int m_EffectivePort = 0;
+ std::string m_ExternalHost;
+ metrics::Meter m_RequestMeter;
+ std::string m_DefaultRedirect;
+ std::atomic<uint64_t> m_ActiveWebSocketConnections{0};
+ std::atomic<uint64_t> m_WsFramesReceived{0};
+ std::atomic<uint64_t> m_WsFramesSent{0};
+ std::atomic<uint64_t> m_WsBytesReceived{0};
+ std::atomic<uint64_t> m_WsBytesSent{0};
virtual void OnRegisterService(HttpService& Service) = 0;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) = 0;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) = 0;
virtual void OnRun(bool IsInteractiveSession) = 0;
virtual void OnRequestExit() = 0;
virtual void OnClose() = 0;
+
+protected:
+ virtual std::string OnGetExternalHost() const;
};
struct HttpServerPluginConfig
@@ -236,7 +351,7 @@ public:
inline HttpServerRequest& ServerRequest() { return m_HttpRequest; }
private:
- HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {}
+ explicit HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {}
~HttpRouterRequest() = default;
HttpRouterRequest(const HttpRouterRequest&) = delete;
@@ -385,7 +500,7 @@ public:
~HttpRpcHandler();
HttpRpcHandler(const HttpRpcHandler&) = delete;
- HttpRpcHandler operator=(const HttpRpcHandler&) = delete;
+ HttpRpcHandler& operator=(const HttpRpcHandler&) = delete;
void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction);
@@ -401,17 +516,7 @@ private:
bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef);
-struct IHttpStatsProvider
-{
- virtual void HandleStatsRequest(HttpServerRequest& Request) = 0;
-};
-
-struct IHttpStatsService
-{
- virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
- virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
-};
-
-void http_forcelink(); // internal
+void http_forcelink(); // internal
+void websocket_forcelink(); // internal
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h
index e6fea6765..460315faf 100644
--- a/src/zenhttp/include/zenhttp/httpstats.h
+++ b/src/zenhttp/include/zenhttp/httpstats.h
@@ -3,23 +3,50 @@
#pragma once
#include <zencore/logging.h>
+#include <zencore/thread.h>
#include <zenhttp/httpserver.h>
+#include <zenhttp/websocket.h>
+#include <atomic>
#include <map>
+#include <memory>
+#include <thread>
+#include <vector>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio/io_context.hpp>
+#include <asio/steady_timer.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
-class HttpStatsService : public HttpService, public IHttpStatsService
+class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler
{
public:
- HttpStatsService();
+ /// Construct without an io_context — optionally uses a dedicated push thread
+ /// for WebSocket stats broadcasting.
+ explicit HttpStatsService(bool EnableWebSockets = false);
+
+ /// Construct with an external io_context — uses an asio timer instead
+ /// of a dedicated thread for WebSocket stats broadcasting.
+ /// The caller must ensure the io_context outlives this service and that
+ /// its run loop is active.
+ HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets = true);
+
~HttpStatsService();
+ void Shutdown();
+
virtual const char* BaseUri() const override;
virtual void HandleRequest(HttpServerRequest& Request) override;
virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override;
virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override;
+ // IWebSocketHandler
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
+
private:
LoggerRef m_Log;
HttpRequestRouter m_Router;
@@ -28,6 +55,22 @@ private:
RwLock m_Lock;
std::map<std::string, IHttpStatsProvider*> m_Providers;
+
+ // WebSocket push
+ RwLock m_WsConnectionsLock;
+ std::vector<Ref<WebSocketConnection>> m_WsConnections;
+ std::atomic<bool> m_PushEnabled{false};
+
+ void BroadcastStats();
+
+ // Thread-based push (when no io_context is provided)
+ std::thread m_PushThread;
+ Event m_PushEvent;
+ void PushThreadFunction();
+
+ // Timer-based push (when an io_context is provided)
+ std::unique_ptr<asio::steady_timer> m_PushTimer;
+ void EnqueuePushTimer();
};
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h
new file mode 100644
index 000000000..926ec1e3d
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpwsclient.h
@@ -0,0 +1,79 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenhttp.h"
+
+#include <zenhttp/httpclient.h>
+#include <zenhttp/websocket.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio/io_context.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <chrono>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <span>
+#include <string>
+#include <string_view>
+
+namespace zen {
+
+/**
+ * Callback interface for WebSocket client events
+ *
+ * Separate from the server-side IWebSocketHandler because the caller
+ * already owns the HttpWsClient — no Ref<WebSocketConnection> needed.
+ */
+class IWsClientHandler
+{
+public:
+ virtual ~IWsClientHandler() = default;
+
+ virtual void OnWsOpen() = 0;
+ virtual void OnWsMessage(const WebSocketMessage& Msg) = 0;
+ virtual void OnWsClose(uint16_t Code, std::string_view Reason) = 0;
+};
+
+struct HttpWsClientSettings
+{
+ std::string LogCategory = "wsclient";
+ std::chrono::milliseconds ConnectTimeout{5000};
+ std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider;
+};
+
+/**
+ * WebSocket client over TCP (ws:// scheme)
+ *
+ * Uses ASIO for async I/O. Two construction modes:
+ * - Internal io_context + background thread (standalone use)
+ * - External io_context (shared event loop, no internal thread)
+ *
+ * Thread-safe for SendText/SendBinary/Close.
+ */
+class HttpWsClient
+{
+public:
+ HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings = {});
+ HttpWsClient(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings = {});
+
+ ~HttpWsClient();
+
+ HttpWsClient(const HttpWsClient&) = delete;
+ HttpWsClient& operator=(const HttpWsClient&) = delete;
+
+ void Connect();
+ void SendText(std::string_view Text);
+ void SendBinary(std::span<const uint8_t> Data);
+ void Close(uint16_t Code = 1000, std::string_view Reason = {});
+ bool IsOpen() const;
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h
index c90b840da..1a5068580 100644
--- a/src/zenhttp/include/zenhttp/packageformat.h
+++ b/src/zenhttp/include/zenhttp/packageformat.h
@@ -68,7 +68,7 @@ struct CbAttachmentEntry
struct CbAttachmentReferenceHeader
{
uint64_t PayloadByteOffset = 0;
- uint64_t PayloadByteSize = ~0u;
+ uint64_t PayloadByteSize = ~uint64_t(0);
uint16_t AbsolutePathLength = 0;
// This header will be followed by UTF8 encoded absolute path to backing file
diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurity.h b/src/zenhttp/include/zenhttp/security/passwordsecurity.h
new file mode 100644
index 000000000..6b2b548a6
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/security/passwordsecurity.h
@@ -0,0 +1,38 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+class PasswordSecurity
+{
+public:
+ struct Configuration
+ {
+ std::string Password;
+ bool ProtectMachineLocalRequests = false;
+ std::vector<std::string> UnprotectedUris;
+ };
+
+ explicit PasswordSecurity(const Configuration& Config);
+
+ [[nodiscard]] inline std::string_view Password() const { return m_Config.Password; }
+ [[nodiscard]] inline bool ProtectMachineLocalRequests() const { return m_Config.ProtectMachineLocalRequests; }
+ [[nodiscard]] bool IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const;
+
+ bool IsAllowed(std::string_view Password, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest);
+
+private:
+ const Configuration m_Config;
+ tsl::robin_map<uint32_t, uint32_t> m_UnprotectedUriHashes;
+};
+
+void passwordsecurity_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h
new file mode 100644
index 000000000..c098f05ad
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h
@@ -0,0 +1,51 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+#include <zenhttp/security/passwordsecurity.h>
+
+namespace zen {
+
+class PasswordHttpFilter : public IHttpRequestFilter
+{
+public:
+ static constexpr std::string_view TypeName = "password";
+
+ struct Configuration
+ {
+ PasswordSecurity::Configuration PasswordConfig;
+ std::string AuthenticationTypeString;
+ };
+
+ /**
+ * Expected format (Json)
+ * {
+ * "password": { # "Authorization: Basic <username:password base64 encoded>" style
+ * "username": "<username>",
+ * "password": "<password>"
+ * },
+ * "protect-machine-local-requests": false,
+ * "unprotected-uris": [
+ * "/health/",
+ * "/health/info",
+ * "/health/version"
+ * ]
+ * }
+ */
+ static Configuration ReadConfiguration(CbObjectView Config);
+
+ explicit PasswordHttpFilter(const PasswordHttpFilter::Configuration& Config)
+ : m_PasswordSecurity(Config.PasswordConfig)
+ , m_AuthenticationTypeString(Config.AuthenticationTypeString)
+ {
+ }
+
+ virtual Result FilterRequest(HttpServerRequest& Request) override;
+
+private:
+ PasswordSecurity m_PasswordSecurity;
+ const std::string m_AuthenticationTypeString;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
new file mode 100644
index 000000000..bc3293282
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/websocket.h
@@ -0,0 +1,65 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenbase/refcount.h>
+#include <zencore/iobuffer.h>
+
+#include <cstdint>
+#include <span>
+#include <string_view>
+
+namespace zen {
+
+enum class WebSocketOpcode : uint8_t
+{
+ kText = 0x1,
+ kBinary = 0x2,
+ kClose = 0x8,
+ kPing = 0x9,
+ kPong = 0xA
+};
+
+struct WebSocketMessage
+{
+ WebSocketOpcode Opcode = WebSocketOpcode::kText;
+ IoBuffer Payload;
+ uint16_t CloseCode = 0;
+};
+
+/**
+ * Represents an active WebSocket connection
+ *
+ * Derived classes implement the actual transport (e.g. ASIO sockets).
+ * Instances are reference-counted so that both the service layer and
+ * the async read/write loop can share ownership.
+ */
+class WebSocketConnection : public RefCounted
+{
+public:
+ virtual ~WebSocketConnection() = default;
+
+ virtual void SendText(std::string_view Text) = 0;
+ virtual void SendBinary(std::span<const uint8_t> Data) = 0;
+ virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0;
+ virtual bool IsOpen() const = 0;
+};
+
+/**
+ * Interface for services that accept WebSocket upgrades
+ *
+ * An HttpService may additionally implement this interface to indicate
+ * it supports WebSocket connections. The HTTP server checks for this
+ * via dynamic_cast when it sees an Upgrade: websocket request.
+ */
+class IWebSocketHandler
+{
+public:
+ virtual ~IWebSocketHandler() = default;
+
+ virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0;
+ virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0;
+ virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp
index b097a0d3f..2370def0c 100644
--- a/src/zenhttp/monitoring/httpstats.cpp
+++ b/src/zenhttp/monitoring/httpstats.cpp
@@ -3,15 +3,57 @@
#include "zenhttp/httpstats.h"
#include <zencore/compactbinarybuilder.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/trace.h>
namespace zen {
-HttpStatsService::HttpStatsService() : m_Log(logging::Get("stats"))
+HttpStatsService::HttpStatsService(bool EnableWebSockets) : m_Log(logging::Get("stats"))
{
+ if (EnableWebSockets)
+ {
+ m_PushEnabled.store(true);
+ m_PushThread = std::thread([this] { PushThreadFunction(); });
+ }
+}
+
+HttpStatsService::HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets) : m_Log(logging::Get("stats"))
+{
+ if (EnableWebSockets)
+ {
+ m_PushEnabled.store(true);
+ m_PushTimer = std::make_unique<asio::steady_timer>(IoContext);
+ EnqueuePushTimer();
+ }
}
HttpStatsService::~HttpStatsService()
{
+ Shutdown();
+}
+
+void
+HttpStatsService::Shutdown()
+{
+ if (!m_PushEnabled.exchange(false))
+ {
+ return;
+ }
+
+ if (m_PushTimer)
+ {
+ m_PushTimer->cancel();
+ m_PushTimer.reset();
+ }
+
+ if (m_PushThread.joinable())
+ {
+ m_PushEvent.Set();
+ m_PushThread.join();
+ }
+
+ m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); });
}
const char*
@@ -39,6 +81,7 @@ HttpStatsService::UnregisterHandler(std::string_view Id, IHttpStatsProvider& Pro
void
HttpStatsService::HandleRequest(HttpServerRequest& Request)
{
+ ZEN_TRACE_CPU("HttpStatsService::HandleRequest");
using namespace std::literals;
std::string_view Key = Request.RelativeUri();
@@ -89,4 +132,154 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request)
}
}
+//////////////////////////////////////////////////////////////////////////
+//
+// IWebSocketHandler
+//
+
+void
+HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+{
+ ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen");
+ ZEN_INFO("Stats WebSocket client connected");
+
+ m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
+
+ // Send initial state immediately
+ if (m_PushThread.joinable())
+ {
+ m_PushEvent.Set();
+ }
+}
+
+void
+HttpStatsService::OnWebSocketMessage(WebSocketConnection& /*Conn*/, const WebSocketMessage& /*Msg*/)
+{
+ // No client-to-server messages expected
+}
+
+void
+HttpStatsService::OnWebSocketClose(WebSocketConnection& Conn, [[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason)
+{
+ ZEN_TRACE_CPU("HttpStatsService::OnWebSocketClose");
+ ZEN_INFO("Stats WebSocket client disconnected (code {})", Code);
+
+ m_WsConnectionsLock.WithExclusiveLock([&] {
+ auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) {
+ return C.Get() == &Conn;
+ });
+ m_WsConnections.erase(It, m_WsConnections.end());
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Stats broadcast
+//
+
+void
+HttpStatsService::BroadcastStats()
+{
+ ZEN_TRACE_CPU("HttpStatsService::BroadcastStats");
+ std::vector<Ref<WebSocketConnection>> Connections;
+ m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; });
+
+ if (Connections.empty())
+ {
+ return;
+ }
+
+ // Collect stats from all providers
+ ExtendableStringBuilder<4096> JsonBuilder;
+ JsonBuilder.Append("{");
+
+ bool First = true;
+ {
+ RwLock::SharedLockScope _(m_Lock);
+ for (auto& [Id, Provider] : m_Providers)
+ {
+ CbObject Stats = Provider->CollectStats();
+ if (!Stats)
+ {
+ continue;
+ }
+
+ if (!First)
+ {
+ JsonBuilder.Append(",");
+ }
+ First = false;
+
+ // Emit as "provider_id": { ... }
+ JsonBuilder.Append("\"");
+ JsonBuilder.Append(Id);
+ JsonBuilder.Append("\":");
+
+ ExtendableStringBuilder<2048> StatsJson;
+ Stats.ToJson(StatsJson);
+ JsonBuilder.Append(StatsJson.ToView());
+ }
+ }
+
+ JsonBuilder.Append("}");
+
+ std::string_view Json = JsonBuilder.ToView();
+ for (auto& Conn : Connections)
+ {
+ if (Conn->IsOpen())
+ {
+ Conn->SendText(Json);
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Thread-based push (fallback when no io_context)
+//
+
+void
+HttpStatsService::PushThreadFunction()
+{
+ SetCurrentThreadName("stats_ws_push");
+
+ while (m_PushEnabled.load())
+ {
+ m_PushEvent.Wait(5000);
+ m_PushEvent.Reset();
+
+ if (!m_PushEnabled.load())
+ {
+ break;
+ }
+
+ BroadcastStats();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Timer-based push (when io_context is provided)
+//
+
+void
+HttpStatsService::EnqueuePushTimer()
+{
+ if (!m_PushTimer)
+ {
+ return;
+ }
+
+ m_PushTimer->expires_after(std::chrono::seconds(5));
+ m_PushTimer->async_wait([this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ return;
+ }
+
+ BroadcastStats();
+ EnqueuePushTimer();
+ });
+}
+
} // namespace zen
diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp
index 708238224..cbfe4d889 100644
--- a/src/zenhttp/packageformat.cpp
+++ b/src/zenhttp/packageformat.cpp
@@ -581,7 +581,7 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize);
AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView());
- Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy});
+ Attachments.emplace_back(CbAttachment(SharedBuffer{AttachmentBufferCopy}, Entry.AttachmentHash));
}
else
{
@@ -805,6 +805,8 @@ CbPackageReader::Finalize()
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("http.packageformat");
+
TEST_CASE("CbPackage.Serialization")
{
// Make a test package
@@ -926,6 +928,8 @@ TEST_CASE("CbPackage.LocalRef")
Reader.Finalize();
}
+TEST_SUITE_END();
+
void
forcelink_packageformat()
{
diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp
new file mode 100644
index 000000000..0e3a743c3
--- /dev/null
+++ b/src/zenhttp/security/passwordsecurity.cpp
@@ -0,0 +1,176 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenhttp/security/passwordsecurity.h"
+#include <zencore/compactbinaryutil.h>
+#include <zencore/fmtutils.h>
+#include <zencore/string.h>
+
+#if ZEN_WITH_TESTS
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/testing.h>
+#endif // ZEN_WITH_TESTS
+
+namespace zen {
+using namespace std::literals;
+
+PasswordSecurity::PasswordSecurity(const Configuration& Config) : m_Config(Config)
+{
+ m_UnprotectedUriHashes.reserve(m_Config.UnprotectedUris.size());
+ for (uint32_t Index = 0; Index < m_Config.UnprotectedUris.size(); Index++)
+ {
+ const std::string& UnprotectedUri = m_Config.UnprotectedUris[Index];
+ if (auto Result = m_UnprotectedUriHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second)
+ {
+ throw std::runtime_error(fmt::format(
+ "password security unprotected uris does not generate unique hashes. Uri #{} ('{}') collides with uri #{} ('{}')",
+ Index + 1,
+ UnprotectedUri,
+ Result.first->second + 1,
+ m_Config.UnprotectedUris[Result.first->second]));
+ }
+ }
+}
+
+bool
+PasswordSecurity::IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const
+{
+ if (!m_Config.UnprotectedUris.empty())
+ {
+ uint32_t UriHash = HashStringDjb2(std::array<const std::string_view, 2>{BaseUri, RelativeUri});
+ if (auto It = m_UnprotectedUriHashes.find(UriHash); It != m_UnprotectedUriHashes.end())
+ {
+ const std::string_view& UnprotectedUri = m_Config.UnprotectedUris[It->second];
+ if (UnprotectedUri.length() == BaseUri.length() + RelativeUri.length())
+ {
+ if (UnprotectedUri.substr(0, BaseUri.length()) == BaseUri && UnprotectedUri.substr(BaseUri.length()) == RelativeUri)
+ {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+bool
+PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest)
+{
+ if (IsUnprotectedUri(BaseUri, RelativeUri))
+ {
+ return true;
+ }
+ if (!ProtectMachineLocalRequests() && IsMachineLocalRequest)
+ {
+ return true;
+ }
+ if (Password().empty())
+ {
+ return true;
+ }
+ if (Password() == InPassword)
+ {
+ return true;
+ }
+ return false;
+}
+
+#if ZEN_WITH_TESTS
+
+TEST_SUITE_BEGIN("http.passwordsecurity");
+
+TEST_CASE("passwordsecurity.allowanything")
+{
+ PasswordSecurity Anything({});
+ CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+}
+
+TEST_CASE("passwordsecurity.allowalllocal")
+{
+ PasswordSecurity AllLocal({.Password = "123456"});
+ CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+}
+
+TEST_CASE("passwordsecurity.allowonlypassword")
+{
+ PasswordSecurity AllLocal({.Password = "123456", .ProtectMachineLocalRequests = true});
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+}
+
+TEST_CASE("passwordsecurity.allowsomeexternaluris")
+{
+ PasswordSecurity AllLocal(
+ {.Password = "123456", .ProtectMachineLocalRequests = false, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})});
+ CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+}
+
+TEST_CASE("passwordsecurity.allowsomelocaluris")
+{
+ PasswordSecurity AllLocal(
+ {.Password = "123456", .ProtectMachineLocalRequests = true, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})});
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+}
+
+TEST_CASE("passwordsecurity.conflictingunprotecteduris")
+{
+ try
+ {
+ PasswordSecurity AllLocal({.Password = "123456",
+ .ProtectMachineLocalRequests = true,
+ .UnprotectedUris = std::vector<std::string>({"/free/access", "/free/access"})});
+ CHECK(false);
+ }
+ catch (const std::runtime_error& Ex)
+ {
+ CHECK_EQ(Ex.what(),
+ std::string("password security unprotected uris does not generate unique hashes. Uri #2 ('/free/access') collides with "
+ "uri #1 ('/free/access')"));
+ }
+}
+
+TEST_SUITE_END();
+
+void
+passwordsecurity_forcelink()
+{
+}
+#endif // ZEN_WITH_TESTS
+
+} // namespace zen
diff --git a/src/zenhttp/security/passwordsecurityfilter.cpp b/src/zenhttp/security/passwordsecurityfilter.cpp
new file mode 100644
index 000000000..87d8cc275
--- /dev/null
+++ b/src/zenhttp/security/passwordsecurityfilter.cpp
@@ -0,0 +1,56 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenhttp/security/passwordsecurityfilter.h"
+
+#include <zencore/base64.h>
+#include <zencore/compactbinaryutil.h>
+#include <zencore/fmtutils.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+PasswordHttpFilter::Configuration
+PasswordHttpFilter::ReadConfiguration(CbObjectView Config)
+{
+ Configuration Result;
+ if (CbObjectView PasswordType = Config["basic"sv].AsObjectView(); PasswordType)
+ {
+ Result.AuthenticationTypeString = "Basic ";
+ std::string_view Username = PasswordType["username"sv].AsString();
+ std::string_view Password = PasswordType["password"sv].AsString();
+ std::string UsernamePassword = fmt::format("{}:{}", Username, Password);
+ Result.PasswordConfig.Password.resize(Base64::GetEncodedDataSize(uint32_t(UsernamePassword.length())));
+ Base64::Encode(reinterpret_cast<const uint8_t*>(UsernamePassword.data()),
+ uint32_t(UsernamePassword.size()),
+ const_cast<char*>(Result.PasswordConfig.Password.data()));
+ }
+ Result.PasswordConfig.ProtectMachineLocalRequests = Config["protect-machine-local-requests"sv].AsBool();
+ Result.PasswordConfig.UnprotectedUris = compactbinary_helpers::ReadArray<std::string>("unprotected-uris"sv, Config);
+ return Result;
+}
+
+IHttpRequestFilter::Result
+PasswordHttpFilter::FilterRequest(HttpServerRequest& Request)
+{
+ std::string_view Password;
+ std::string_view AuthorizationHeader = Request.GetAuthorizationHeader();
+ size_t AuthorizationHeaderLength = AuthorizationHeader.length();
+ if (AuthorizationHeaderLength > m_AuthenticationTypeString.length())
+ {
+ if (StrCaseCompare(AuthorizationHeader.data(), m_AuthenticationTypeString.c_str(), m_AuthenticationTypeString.length()) == 0)
+ {
+ Password = AuthorizationHeader.substr(m_AuthenticationTypeString.length());
+ }
+ }
+
+ bool IsAllowed =
+ m_PasswordSecurity.IsAllowed(Password, Request.Service().BaseUri(), Request.RelativeUri(), Request.IsLocalMachineRequest());
+ if (IsAllowed)
+ {
+ return Result::Accepted;
+ }
+ return Result::Forbidden;
+}
+
+} // namespace zen
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 18a0f6a40..f5178ebe8 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -7,12 +7,15 @@
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/memory/llm.h>
+#include <zencore/system.h>
#include <zencore/thread.h>
#include <zencore/trace.h>
#include <zencore/windows.h>
#include <zenhttp/httpserver.h>
#include "httpparser.h"
+#include "wsasio.h"
+#include "wsframecodec.h"
#include <EASTL/fixed_vector.h>
@@ -89,15 +92,19 @@ IsIPv6AvailableSysctl(void)
char buf[16];
if (fgets(buf, sizeof(buf), f))
{
- fclose(f);
// 0 means IPv6 enabled, 1 means disabled
val = atoi(buf);
}
+ fclose(f);
}
return val == 0;
}
+#endif // ZEN_PLATFORM_LINUX
+namespace zen {
+
+#if ZEN_PLATFORM_LINUX
bool
IsIPv6Capable()
{
@@ -121,8 +128,6 @@ IsIPv6Capable()
}
#endif
-namespace zen {
-
const FLLMTag&
GetHttpasioTag()
{
@@ -145,7 +150,7 @@ inline LoggerRef
InitLogger()
{
LoggerRef Logger = logging::Get("asio");
- // Logger.set_level(spdlog::level::trace);
+ // Logger.SetLogLevel(logging::Trace);
return Logger;
}
@@ -496,16 +501,21 @@ public:
HttpAsioServerImpl();
~HttpAsioServerImpl();
- void Initialize(std::filesystem::path DataDir);
- int Start(uint16_t Port, const AsioConfig& Config);
- void Stop();
- void RegisterService(const char* UrlPath, HttpService& Service);
- HttpService* RouteRequest(std::string_view Url);
+ void Initialize(std::filesystem::path DataDir);
+ int Start(uint16_t Port, const AsioConfig& Config);
+ void Stop();
+ void RegisterService(const char* UrlPath, HttpService& Service);
+ void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter);
+ HttpService* RouteRequest(std::string_view Url);
+ IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
+
+ bool IsLoopbackOnly() const;
asio::io_service m_IoService;
asio::io_service::work m_Work{m_IoService};
std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
std::vector<std::thread> m_ThreadPool;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
LoggerRef m_RequestLog;
HttpServerTracer m_RequestTracer;
@@ -518,6 +528,11 @@ public:
RwLock m_Lock;
std::vector<ServiceEntry> m_UriHandlers;
+
+ std::atomic<uint64_t> m_TotalBytesReceived{0};
+ std::atomic<uint64_t> m_TotalBytesSent{0};
+
+ HttpServer* m_HttpServer = nullptr;
};
/**
@@ -527,12 +542,21 @@ public:
class HttpAsioServerRequest : public HttpServerRequest
{
public:
- HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer, uint32_t RequestNumber);
+ HttpAsioServerRequest(HttpRequestParser& Request,
+ HttpService& Service,
+ IoBuffer PayloadBuffer,
+ uint32_t RequestNumber,
+ bool IsLocalMachineRequest,
+ std::string RemoteAddress);
~HttpAsioServerRequest();
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
+ virtual bool IsLocalMachineRequest() const override;
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual std::string_view GetRemoteAddress() const override;
+
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
@@ -548,6 +572,8 @@ public:
HttpRequestParser& m_Request;
uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers
IoBuffer m_PayloadBuffer;
+ bool m_IsLocalMachineRequest;
+ std::string m_RemoteAddress;
std::unique_ptr<HttpResponse> m_Response;
};
@@ -925,6 +951,7 @@ private:
void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, uint32_t RequestNumber, HttpResponse* ResponseToPop);
void CloseConnection();
+ void SendInlineResponse(uint32_t RequestNumber, std::string_view StatusLine, std::string_view Headers = {}, std::string_view Body = {});
HttpAsioServerImpl& m_Server;
asio::streambuf m_RequestBuffer;
@@ -1025,6 +1052,8 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]
}
}
+ m_Server.m_TotalBytesReceived.fetch_add(ByteCount, std::memory_order_relaxed);
+
ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}",
m_ConnectionId,
m_RequestCounter.load(std::memory_order_relaxed),
@@ -1078,6 +1107,8 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
return;
}
+ m_Server.m_TotalBytesSent.fetch_add(ByteCount, std::memory_order_relaxed);
+
ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}",
m_ConnectionId,
RequestNumber,
@@ -1139,10 +1170,91 @@ HttpServerConnection::CloseConnection()
}
void
+HttpServerConnection::SendInlineResponse(uint32_t RequestNumber,
+ std::string_view StatusLine,
+ std::string_view Headers,
+ std::string_view Body)
+{
+ ExtendableStringBuilder<256> ResponseBuilder;
+ ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n";
+ if (!Headers.empty())
+ {
+ ResponseBuilder << Headers;
+ }
+ if (!m_RequestData.IsKeepAlive())
+ {
+ ResponseBuilder << "Connection: close\r\n";
+ }
+ ResponseBuilder << "\r\n";
+ if (!Body.empty())
+ {
+ ResponseBuilder << Body;
+ }
+ auto ResponseView = ResponseBuilder.ToView();
+ IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size());
+ auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize());
+ asio::async_write(
+ *m_Socket.get(),
+ Buffer,
+ [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) {
+ Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
+ });
+}
+
+void
HttpServerConnection::HandleRequest()
{
ZEN_MEMSCOPE(GetHttpasioTag());
+ // WebSocket upgrade detection must happen before the keep-alive check below,
+ // because Upgrade requests have "Connection: Upgrade" which the HTTP parser
+ // treats as non-keep-alive, causing a premature shutdown of the receive side.
+ if (m_RequestData.IsWebSocketUpgrade())
+ {
+ if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url()))
+ {
+ IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service);
+ if (WsHandler && !m_RequestData.SecWebSocketKey().empty())
+ {
+ std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(m_RequestData.SecWebSocketKey());
+
+ auto ResponseStr = std::make_shared<std::string>();
+ ResponseStr->reserve(256);
+ ResponseStr->append(
+ "HTTP/1.1 101 Switching Protocols\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: ");
+ ResponseStr->append(AcceptKey);
+ ResponseStr->append("\r\n\r\n");
+
+ // Send the 101 response on the current socket, then hand the socket off
+ // to a WsAsioConnection for the WebSocket protocol.
+ asio::async_write(*m_Socket,
+ asio::buffer(ResponseStr->data(), ResponseStr->size()),
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
+ return;
+ }
+
+ Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened();
+ Ref<WsAsioConnection> WsConn(
+ new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+ });
+
+ m_RequestState = RequestState::kDone;
+ return;
+ }
+ }
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
+
if (!m_RequestData.IsKeepAlive())
{
m_RequestState = RequestState::kWritingFinal;
@@ -1166,14 +1278,24 @@ HttpServerConnection::HandleRequest()
{
ZEN_TRACE_CPU("asio::HandleRequest");
- HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber);
+ m_Server.m_HttpServer->MarkRequest();
+
+ auto RemoteEndpoint = m_Socket->remote_endpoint();
+ bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address();
+
+ HttpAsioServerRequest Request(m_RequestData,
+ *Service,
+ m_RequestData.Body(),
+ RequestNumber,
+ IsLocalConnection,
+ RemoteEndpoint.address().to_string());
ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber);
const HttpVerb RequestVerb = Request.RequestVerb();
const std::string_view Uri = Request.RelativeUri();
- if (m_Server.m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server.m_RequestLog.ShouldLog(logging::Trace))
{
ZEN_LOG_TRACE(m_Server.m_RequestLog,
"connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})",
@@ -1188,56 +1310,73 @@ HttpServerConnection::HandleRequest()
std::vector<IoBuffer>{Request.ReadPayload()});
}
- if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_Server.FilterRequest(Request);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
{
- try
- {
- Service->HandleRequest(Request);
- }
- catch (const AssertException& AssertEx)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
- }
- catch (const std::system_error& SystemError)
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
{
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ try
{
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ Service->HandleRequest(Request);
}
- else
+ catch (const AssertException& AssertEx)
{
- ZEN_WARN("Caught system error exception while handling request: {}. ({})",
- SystemError.what(),
- SystemError.code().value());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
}
- }
- catch (const std::bad_alloc& BadAlloc)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ catch (const std::system_error& SystemError)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
- }
- catch (const std::exception& ex)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ {
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ }
+ else
+ {
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})",
+ SystemError.what(),
+ SystemError.code().value());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ }
+ }
+ catch (const std::bad_alloc& BadAlloc)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- ZEN_WARN("Caught exception while handling request: {}", ex.what());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
+ }
+ catch (const std::exception& ex)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_WARN("Caught exception while handling request: {}", ex.what());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
}
}
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
+ {
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
+ }
if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response))
{
+ if (Request.ShouldLogRequest())
+ {
+ ZEN_INFO("{} {} {} -> {}", ToString(RequestVerb), Uri, Response->ResponseCode(), NiceBytes(Response->ContentLength()));
+ }
+
// Transmit the response
if (m_RequestData.RequestVerb() == HttpVerb::kHead)
@@ -1278,51 +1417,24 @@ HttpServerConnection::HandleRequest()
}
}
- if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ // If a default redirect is configured and the request is for the root path, send a 302
+ std::string_view DefaultRedirect = m_Server.m_HttpServer->GetDefaultRedirect();
+ if (!DefaultRedirect.empty() && (m_RequestData.Url() == "/" || m_RequestData.Url().empty()))
{
- std::string_view Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "\r\n"sv;
-
- if (!m_RequestData.IsKeepAlive())
- {
- Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Connection: close\r\n"
- "\r\n"sv;
- }
-
- asio::async_write(*m_Socket.get(),
- asio::buffer(Response),
- [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) {
- Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
- });
+ ExtendableStringBuilder<128> Headers;
+ Headers << "Location: " << DefaultRedirect << "\r\nContent-Length: 0\r\n";
+ SendInlineResponse(RequestNumber, "302 Found"sv, Headers.ToView());
+ }
+ else if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ {
+ SendInlineResponse(RequestNumber, "404 NOT FOUND"sv);
}
else
{
- std::string_view Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Content-Length: 23\r\n"
- "Content-Type: text/plain\r\n"
- "\r\n"
- "No suitable route found"sv;
-
- if (!m_RequestData.IsKeepAlive())
- {
- Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Content-Length: 23\r\n"
- "Content-Type: text/plain\r\n"
- "Connection: close\r\n"
- "\r\n"
- "No suitable route found"sv;
- }
-
- asio::async_write(*m_Socket.get(),
- asio::buffer(Response),
- [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) {
- Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
- });
+ SendInlineResponse(RequestNumber,
+ "404 NOT FOUND"sv,
+ "Content-Length: 23\r\nContent-Type: text/plain\r\n"sv,
+ "No suitable route found"sv);
}
}
@@ -1348,8 +1460,11 @@ struct HttpAcceptor
m_Acceptor.set_option(exclusive_address(true));
m_AlternateProtocolAcceptor.set_option(exclusive_address(true));
#else // ZEN_PLATFORM_WINDOWS
- m_Acceptor.set_option(asio::socket_base::reuse_address(false));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(false));
+ // Allow binding to a port in TIME_WAIT so the server can restart immediately
+ // after a previous instance exits. On Linux this does not allow two processes
+ // to actively listen on the same port simultaneously.
+ m_Acceptor.set_option(asio::socket_base::reuse_address(true));
+ m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(true));
#endif // ZEN_PLATFORM_WINDOWS
m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
@@ -1512,7 +1627,7 @@ struct HttpAcceptor
{
ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message());
- return 0;
+ return {};
}
if (EffectivePort != BasePort)
@@ -1569,7 +1684,8 @@ struct HttpAcceptor
void StopAccepting() { m_IsStopped = true; }
- int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); }
+ int GetAcceptPort() const { return m_Acceptor.local_endpoint().port(); }
+ bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); }
bool IsValid() const { return m_IsValid; }
@@ -1632,11 +1748,15 @@ private:
HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request,
HttpService& Service,
IoBuffer PayloadBuffer,
- uint32_t RequestNumber)
+ uint32_t RequestNumber,
+ bool IsLocalMachineRequest,
+ std::string RemoteAddress)
: HttpServerRequest(Service)
, m_Request(Request)
, m_RequestNumber(RequestNumber)
, m_PayloadBuffer(std::move(PayloadBuffer))
+, m_IsLocalMachineRequest(IsLocalMachineRequest)
+, m_RemoteAddress(std::move(RemoteAddress))
{
const int PrefixLength = Service.UriPrefixLength();
@@ -1708,6 +1828,24 @@ HttpAsioServerRequest::ParseRequestId() const
return m_Request.RequestId();
}
+bool
+HttpAsioServerRequest::IsLocalMachineRequest() const
+{
+ return m_IsLocalMachineRequest;
+}
+
+std::string_view
+HttpAsioServerRequest::GetRemoteAddress() const
+{
+ return m_RemoteAddress;
+}
+
+std::string_view
+HttpAsioServerRequest::GetAuthorizationHeader() const
+{
+ return m_Request.AuthorizationHeader();
+}
+
IoBuffer
HttpAsioServerRequest::ReadPayload()
{
@@ -1904,6 +2042,37 @@ HttpAsioServerImpl::RouteRequest(std::string_view Url)
return CandidateService;
}
+void
+HttpAsioServerImpl::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+IHttpRequestFilter::Result
+HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_Lock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+
+ return RequestFilter->FilterRequest(Request);
+}
+
+bool
+HttpAsioServerImpl::IsLoopbackOnly() const
+{
+ return m_Acceptor && m_Acceptor->IsLoopbackOnly();
+}
+
} // namespace zen::asio_http
//////////////////////////////////////////////////////////////////////////
@@ -1916,11 +2085,15 @@ public:
HttpAsioServer(const AsioConfig& Config);
~HttpAsioServer();
- virtual void OnRegisterService(HttpService& Service) override;
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool IsInteractiveSession) override;
- virtual void OnRequestExit() override;
- virtual void OnClose() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual void OnRun(bool IsInteractiveSession) override;
+ virtual void OnRequestExit() override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
+ virtual uint64_t GetTotalBytesReceived() const override;
+ virtual uint64_t GetTotalBytesSent() const override;
private:
Event m_ShutdownEvent;
@@ -1934,6 +2107,7 @@ HttpAsioServer::HttpAsioServer(const AsioConfig& Config)
: m_InitialConfig(Config)
, m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>())
{
+ m_Impl->m_HttpServer = this;
ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), sizeof(HttpRequestParser));
}
@@ -1965,6 +2139,12 @@ HttpAsioServer::OnRegisterService(HttpService& Service)
m_Impl->RegisterService(Service.BaseUri(), Service);
}
+void
+HttpAsioServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ m_Impl->SetHttpRequestFilter(RequestFilter);
+}
+
int
HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
{
@@ -1989,10 +2169,46 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
return m_BasePort;
}
+std::string
+HttpAsioServer::OnGetExternalHost() const
+{
+ if (m_Impl->IsLoopbackOnly())
+ {
+ return "127.0.0.1";
+ }
+
+ // Use the UDP connect trick: connecting a UDP socket to an external address
+ // causes the OS to select the appropriate local interface without sending any data.
+ try
+ {
+ asio::io_service IoService;
+ asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4());
+ Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80));
+ return Sock.local_endpoint().address().to_string();
+ }
+ catch (const std::exception&)
+ {
+ return GetMachineName();
+ }
+}
+
+uint64_t
+HttpAsioServer::GetTotalBytesReceived() const
+{
+ return m_Impl->m_TotalBytesReceived.load(std::memory_order_relaxed);
+}
+
+uint64_t
+HttpAsioServer::GetTotalBytesSent() const
+{
+ return m_Impl->m_TotalBytesSent.load(std::memory_order_relaxed);
+}
+
void
HttpAsioServer::OnRun(bool IsInteractive)
{
- const int WaitTimeout = 1000;
+ const int WaitTimeout = 1000;
+ bool ShutdownRequested = false;
#if ZEN_PLATFORM_WINDOWS
if (IsInteractive)
@@ -2008,12 +2224,13 @@ HttpAsioServer::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#else
if (IsInteractive)
{
@@ -2022,8 +2239,8 @@ HttpAsioServer::OnRun(bool IsInteractive)
do
{
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#endif
}
diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h
index c483dfc28..3ec1141a7 100644
--- a/src/zenhttp/servers/httpasio.h
+++ b/src/zenhttp/servers/httpasio.h
@@ -15,4 +15,6 @@ struct AsioConfig
Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config);
+bool IsIPv6Capable();
+
} // namespace zen
diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp
index 31cb04be5..584e06cbf 100644
--- a/src/zenhttp/servers/httpmulti.cpp
+++ b/src/zenhttp/servers/httpmulti.cpp
@@ -54,9 +54,19 @@ HttpMultiServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
}
void
+HttpMultiServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ for (auto& Server : m_Servers)
+ {
+ Server->SetHttpRequestFilter(RequestFilter);
+ }
+}
+
+void
HttpMultiServer::OnRun(bool IsInteractiveSession)
{
- const int WaitTimeout = 1000;
+ const int WaitTimeout = 1000;
+ bool ShutdownRequested = false;
#if ZEN_PLATFORM_WINDOWS
if (IsInteractiveSession)
@@ -72,12 +82,13 @@ HttpMultiServer::OnRun(bool IsInteractiveSession)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#else
if (IsInteractiveSession)
{
@@ -86,8 +97,8 @@ HttpMultiServer::OnRun(bool IsInteractiveSession)
do
{
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#endif
}
@@ -106,6 +117,16 @@ HttpMultiServer::OnClose()
}
}
+std::string
+HttpMultiServer::OnGetExternalHost() const
+{
+ if (!m_Servers.empty())
+ {
+ return std::string(m_Servers.front()->GetExternalHost());
+ }
+ return HttpServer::OnGetExternalHost();
+}
+
void
HttpMultiServer::AddServer(Ref<HttpServer> Server)
{
diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h
index ae0ed74cf..97699828a 100644
--- a/src/zenhttp/servers/httpmulti.h
+++ b/src/zenhttp/servers/httpmulti.h
@@ -15,11 +15,13 @@ public:
HttpMultiServer();
~HttpMultiServer();
- virtual void OnRegisterService(HttpService& Service) override;
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool IsInteractiveSession) override;
- virtual void OnRequestExit() override;
- virtual void OnClose() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnRun(bool IsInteractiveSession) override;
+ virtual void OnRequestExit() override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
void AddServer(Ref<HttpServer> Server);
diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp
index 0ec1cb3c4..9bb7ef3bc 100644
--- a/src/zenhttp/servers/httpnull.cpp
+++ b/src/zenhttp/servers/httpnull.cpp
@@ -24,6 +24,12 @@ HttpNullServer::OnRegisterService(HttpService& Service)
ZEN_UNUSED(Service);
}
+void
+HttpNullServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ ZEN_UNUSED(RequestFilter);
+}
+
int
HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
{
@@ -34,7 +40,8 @@ HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
void
HttpNullServer::OnRun(bool IsInteractiveSession)
{
- const int WaitTimeout = 1000;
+ const int WaitTimeout = 1000;
+ bool ShutdownRequested = false;
#if ZEN_PLATFORM_WINDOWS
if (IsInteractiveSession)
@@ -50,12 +57,13 @@ HttpNullServer::OnRun(bool IsInteractiveSession)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#else
if (IsInteractiveSession)
{
@@ -64,8 +72,8 @@ HttpNullServer::OnRun(bool IsInteractiveSession)
do
{
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!ShutdownRequested);
#endif
}
diff --git a/src/zenhttp/servers/httpnull.h b/src/zenhttp/servers/httpnull.h
index ce7230938..52838f012 100644
--- a/src/zenhttp/servers/httpnull.h
+++ b/src/zenhttp/servers/httpnull.h
@@ -18,6 +18,7 @@ public:
~HttpNullServer();
virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
virtual void OnRun(bool IsInteractiveSession) override;
virtual void OnRequestExit() override;
diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp
index 93094e21b..918b55dc6 100644
--- a/src/zenhttp/servers/httpparser.cpp
+++ b/src/zenhttp/servers/httpparser.cpp
@@ -12,13 +12,17 @@ namespace zen {
using namespace std::literals;
-static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
-static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
-static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
-static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
-static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
-static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
-static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+static constexpr uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
+static constexpr uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
+static constexpr uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
+static constexpr uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
+static constexpr uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
+static constexpr uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
+static constexpr uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+static constexpr uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv);
+static constexpr uint32_t HashUpgrade = HashStringAsLowerDjb2("Upgrade"sv);
+static constexpr uint32_t HashSecWebSocketKey = HashStringAsLowerDjb2("Sec-WebSocket-Key"sv);
+static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-WebSocket-Version"sv);
//////////////////////////////////////////////////////////////////////////
//
@@ -142,41 +146,62 @@ HttpRequestParser::ParseCurrentHeader()
const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName);
const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1);
- if (HeaderHash == HashContentLength)
+ switch (HeaderHash)
{
- m_ContentLengthHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashAccept)
- {
- m_AcceptHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashContentType)
- {
- m_ContentTypeHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashSession)
- {
- m_SessionId = Oid::TryFromHexString(HeaderValue);
- }
- else if (HeaderHash == HashRequest)
- {
- std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
- }
- else if (HeaderHash == HashExpect)
- {
- if (HeaderValue == "100-continue"sv)
- {
- // We don't currently do anything with this
- m_Expect100Continue = true;
- }
- else
- {
- ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
- }
- }
- else if (HeaderHash == HashRange)
- {
- m_RangeHeaderIndex = CurrentHeaderIndex;
+ case HashContentLength:
+ m_ContentLengthHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashAccept:
+ m_AcceptHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashContentType:
+ m_ContentTypeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashAuthorization:
+ m_AuthorizationHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSession:
+ m_SessionId = Oid::TryFromHexString(HeaderValue);
+ break;
+
+ case HashRequest:
+ std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
+ break;
+
+ case HashExpect:
+ if (HeaderValue == "100-continue"sv)
+ {
+ // We don't currently do anything with this
+ m_Expect100Continue = true;
+ }
+ else
+ {
+ ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
+ }
+ break;
+
+ case HashRange:
+ m_RangeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashUpgrade:
+ m_UpgradeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSecWebSocketKey:
+ m_SecWebSocketKeyHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSecWebSocketVersion:
+ m_SecWebSocketVersionHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ default:
+ break;
}
}
@@ -220,11 +245,6 @@ NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl)
NormalizedUrl.reserve(UrlLength);
NormalizedUrl.append(Url, UrlIndex);
}
-
- if (!LastCharWasSeparator)
- {
- NormalizedUrl.push_back('/');
- }
}
else if (!NormalizedUrl.empty())
{
@@ -305,6 +325,7 @@ HttpRequestParser::OnHeadersComplete()
if (ContentLength)
{
+ // TODO: should sanity-check content length here
m_BodyBuffer = IoBuffer(ContentLength);
}
@@ -324,9 +345,9 @@ HttpRequestParser::OnHeadersComplete()
int
HttpRequestParser::OnBody(const char* Data, size_t Bytes)
{
- if (m_BodyPosition + Bytes > m_BodyBuffer.Size())
+ if ((m_BodyPosition + Bytes) > m_BodyBuffer.Size())
{
- ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes",
+ ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes",
(m_BodyPosition + Bytes) - m_BodyBuffer.Size());
return 1;
}
@@ -337,7 +358,7 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes)
{
if (m_BodyPosition != m_BodyBuffer.Size())
{
- ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
+ ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
return 1;
}
}
@@ -353,13 +374,18 @@ HttpRequestParser::ResetState()
m_HeaderEntries.clear();
- m_ContentLengthHeaderIndex = -1;
- m_AcceptHeaderIndex = -1;
- m_ContentTypeHeaderIndex = -1;
- m_RangeHeaderIndex = -1;
- m_Expect100Continue = false;
- m_BodyBuffer = {};
- m_BodyPosition = 0;
+ m_ContentLengthHeaderIndex = -1;
+ m_AcceptHeaderIndex = -1;
+ m_ContentTypeHeaderIndex = -1;
+ m_RangeHeaderIndex = -1;
+ m_AuthorizationHeaderIndex = -1;
+ m_UpgradeHeaderIndex = -1;
+ m_SecWebSocketKeyHeaderIndex = -1;
+ m_SecWebSocketVersionHeaderIndex = -1;
+ m_RequestVerb = HttpVerb::kGet;
+ m_Expect100Continue = false;
+ m_BodyBuffer = {};
+ m_BodyPosition = 0;
m_HeaderData.clear();
m_NormalizedUrl.clear();
@@ -416,4 +442,21 @@ HttpRequestParser::OnMessageComplete()
}
}
+bool
+HttpRequestParser::IsWebSocketUpgrade() const
+{
+ std::string_view Upgrade = GetHeaderValue(m_UpgradeHeaderIndex);
+ if (Upgrade.empty())
+ {
+ return false;
+ }
+
+ // Case-insensitive check for "websocket"
+ if (Upgrade.size() != 9)
+ {
+ return false;
+ }
+ return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0;
+}
+
} // namespace zen
diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h
index 0d2664ec5..23ad9d8fb 100644
--- a/src/zenhttp/servers/httpparser.h
+++ b/src/zenhttp/servers/httpparser.h
@@ -46,6 +46,12 @@ struct HttpRequestParser
std::string_view RangeHeader() const { return GetHeaderValue(m_RangeHeaderIndex); }
+ std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); }
+
+ std::string_view UpgradeHeader() const { return GetHeaderValue(m_UpgradeHeaderIndex); }
+ std::string_view SecWebSocketKey() const { return GetHeaderValue(m_SecWebSocketKeyHeaderIndex); }
+ bool IsWebSocketUpgrade() const;
+
private:
struct HeaderRange
{
@@ -83,7 +89,11 @@ private:
int8_t m_AcceptHeaderIndex;
int8_t m_ContentTypeHeaderIndex;
int8_t m_RangeHeaderIndex;
- HttpVerb m_RequestVerb;
+ int8_t m_AuthorizationHeaderIndex;
+ int8_t m_UpgradeHeaderIndex;
+ int8_t m_SecWebSocketKeyHeaderIndex;
+ int8_t m_SecWebSocketVersionHeaderIndex;
+ HttpVerb m_RequestVerb = HttpVerb::kGet;
std::atomic_bool m_KeepAlive{false};
bool m_Expect100Continue = false;
int m_RequestId = -1;
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp
index b9217ed87..4bf8c61bb 100644
--- a/src/zenhttp/servers/httpplugin.cpp
+++ b/src/zenhttp/servers/httpplugin.cpp
@@ -96,6 +96,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
// HttpPluginServer
virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
virtual void OnRun(bool IsInteractiveSession) override;
virtual void OnRequestExit() override;
@@ -104,7 +105,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
virtual void AddPlugin(Ref<TransportPlugin> Plugin) override;
virtual void RemovePlugin(Ref<TransportPlugin> Plugin) override;
- HttpService* RouteRequest(std::string_view Url);
+ HttpService* RouteRequest(std::string_view Url);
+ IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
struct ServiceEntry
{
@@ -112,7 +114,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
HttpService* Service;
};
- bool m_IsInitialized = false;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+ bool m_IsInitialized = false;
RwLock m_Lock;
std::vector<ServiceEntry> m_UriHandlers;
std::vector<Ref<TransportPlugin>> m_Plugins;
@@ -120,7 +123,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
bool m_IsRequestLoggingEnabled = false;
LoggerRef m_RequestLog;
std::atomic_uint32_t m_ConnectionIdCounter{0};
- int m_BasePort;
+ int m_BasePort = 0;
HttpServerTracer m_RequestTracer;
@@ -143,8 +146,11 @@ public:
HttpPluginServerRequest(const HttpPluginServerRequest&) = delete;
HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete;
- virtual Oid ParseSessionId() const override;
- virtual uint32_t ParseRequestId() const override;
+ // As this is plugin transport connection used for specialized connections we assume it is not a machine local connection
+ virtual bool IsLocalMachineRequest() const /* override*/ { return false; }
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual Oid ParseSessionId() const override;
+ virtual uint32_t ParseRequestId() const override;
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
@@ -288,7 +294,7 @@ HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPlug
ConnectionName = "anonymous";
}
- ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('')", m_ConnectionId, ConnectionName);
+ ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('{}')", m_ConnectionId, ConnectionName);
}
uint32_t
@@ -372,12 +378,14 @@ HttpPluginConnectionHandler::HandleRequest()
{
ZEN_TRACE_CPU("http_plugin::HandleRequest");
+ m_Server->MarkRequest();
+
HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body());
const HttpVerb RequestVerb = Request.RequestVerb();
const std::string_view Uri = Request.RelativeUri();
- if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server->m_RequestLog.ShouldLog(logging::Trace))
{
ZEN_LOG_TRACE(m_Server->m_RequestLog,
"connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})",
@@ -392,53 +400,65 @@ HttpPluginConnectionHandler::HandleRequest()
std::vector<IoBuffer>{Request.ReadPayload()});
}
- if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_Server->FilterRequest(Request);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
{
- try
- {
- Service->HandleRequest(Request);
- }
- catch (const AssertException& AssertEx)
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
{
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
- }
- catch (const std::system_error& SystemError)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ try
{
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ Service->HandleRequest(Request);
}
- else
+ catch (const AssertException& AssertEx)
{
- ZEN_WARN("Caught system error exception while handling request: {}. ({})",
- SystemError.what(),
- SystemError.code().value());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
}
- }
- catch (const std::bad_alloc& BadAlloc)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ catch (const std::system_error& SystemError)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ {
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ }
+ else
+ {
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})",
+ SystemError.what(),
+ SystemError.code().value());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ }
+ }
+ catch (const std::bad_alloc& BadAlloc)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
- }
- catch (const std::exception& ex)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
+ }
+ catch (const std::exception& ex)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- ZEN_WARN("Caught exception while handling request: {}", ex.what());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ ZEN_WARN("Caught exception while handling request: {}", ex.what());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
}
}
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
+ {
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
+ }
if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response))
{
@@ -462,7 +482,7 @@ HttpPluginConnectionHandler::HandleRequest()
const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers();
- if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server->m_RequestLog.ShouldLog(logging::Trace))
{
m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber),
ResponseBuffers);
@@ -618,6 +638,12 @@ HttpPluginServerRequest::~HttpPluginServerRequest()
{
}
+std::string_view
+HttpPluginServerRequest::GetAuthorizationHeader() const
+{
+ return m_Request.AuthorizationHeader();
+}
+
Oid
HttpPluginServerRequest::ParseSessionId() const
{
@@ -750,6 +776,13 @@ HttpPluginServerImpl::OnInitialize(int InBasePort, std::filesystem::path DataDir
}
void
+HttpPluginServerImpl::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+void
HttpPluginServerImpl::OnClose()
{
if (!m_IsInitialized)
@@ -806,6 +839,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
@@ -894,6 +928,22 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url)
return CandidateService;
}
+IHttpRequestFilter::Result
+HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_Lock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ return RequestFilter->FilterRequest(Request);
+}
+
//////////////////////////////////////////////////////////////////////////
struct HttpPluginServerImpl;
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 54cc0c22d..dfe6bb6aa 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -12,6 +12,7 @@
#include <zencore/memory/llm.h>
#include <zencore/scopeguard.h>
#include <zencore/string.h>
+#include <zencore/system.h>
#include <zencore/timer.h>
#include <zencore/trace.h>
#include <zenhttp/packageformat.h>
@@ -25,7 +26,9 @@
# include <zencore/workthreadpool.h>
# include "iothreadpool.h"
+# include <atomic>
# include <http.h>
+# include <asio.hpp> // for resolving addresses for GetExternalHost
namespace zen {
@@ -72,6 +75,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In
OutString.Append("unknown");
}
+class HttpSysServerRequest;
+
/**
* @brief Windows implementation of HTTP server based on http.sys
*
@@ -83,6 +88,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In
class HttpSysServer : public HttpServer
{
friend class HttpSysTransaction;
+ friend class HttpMessageResponseRequest;
+ friend struct InitialRequestHandler;
public:
explicit HttpSysServer(const HttpSysConfig& Config);
@@ -90,17 +97,23 @@ public:
// HttpServer interface implementation
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool TestMode) override;
- virtual void OnRequestExit() override;
- virtual void OnRegisterService(HttpService& Service) override;
- virtual void OnClose() override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnRun(bool TestMode) override;
+ virtual void OnRequestExit() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
+ virtual uint64_t GetTotalBytesReceived() const override;
+ virtual uint64_t GetTotalBytesSent() const override;
WorkerThreadPool& WorkPool();
inline bool IsOk() const { return m_IsOk; }
inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; }
+ IHttpRequestFilter::Result FilterRequest(HttpSysServerRequest& Request);
+
private:
int InitializeServer(int BasePort);
void Cleanup();
@@ -124,8 +137,8 @@ private:
std::unique_ptr<WinIoThreadPool> m_IoThreadPool;
- RwLock m_AsyncWorkPoolInitLock;
- WorkerThreadPool* m_AsyncWorkPool = nullptr;
+ RwLock m_AsyncWorkPoolInitLock;
+ std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr;
std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
@@ -137,6 +150,12 @@ private:
int32_t m_MaxPendingRequests = 128;
Event m_ShutdownEvent;
HttpSysConfig m_InitialConfig;
+
+ RwLock m_RequestFilterLock;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+
+ std::atomic<uint64_t> m_TotalBytesReceived{0};
+ std::atomic<uint64_t> m_TotalBytesSent{0};
};
} // namespace zen
@@ -144,6 +163,10 @@ private:
#if ZEN_WITH_HTTPSYS
+# include "httpsys_iocontext.h"
+# include "wshttpsys.h"
+# include "wsframecodec.h"
+
# include <conio.h>
# include <mstcpip.h>
# pragma comment(lib, "httpapi.lib")
@@ -313,6 +336,10 @@ public:
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
+ virtual bool IsLocalMachineRequest() const override;
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual std::string_view GetRemoteAddress() const override;
+
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
@@ -320,16 +347,19 @@ public:
virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
virtual bool TryGetRanges(HttpRanges& Ranges) override;
+ void LogRequest(HttpMessageResponseRequest* Response);
+
using HttpServerRequest::WriteResponse;
HttpSysServerRequest(const HttpSysServerRequest&) = delete;
HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete;
- HttpSysTransaction& m_HttpTx;
- HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
- IoBuffer m_PayloadBuffer;
- ExtendableStringBuilder<128> m_UriUtf8;
- ExtendableStringBuilder<128> m_QueryStringUtf8;
+ HttpSysTransaction& m_HttpTx;
+ HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
+ IoBuffer m_PayloadBuffer;
+ ExtendableStringBuilder<128> m_UriUtf8;
+ ExtendableStringBuilder<128> m_QueryStringUtf8;
+ mutable ExtendableStringBuilder<64> m_RemoteAddress;
};
/** HTTP transaction
@@ -363,7 +393,7 @@ public:
PTP_IO Iocp();
HANDLE RequestQueueHandle();
- inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
+ inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; }
inline HttpSysServer& Server() { return m_HttpServer; }
inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
@@ -380,8 +410,8 @@ public:
};
private:
- OVERLAPPED m_HttpOverlapped{};
- HttpSysServer& m_HttpServer;
+ HttpSysIoContext m_IoContext{};
+ HttpSysServer& m_HttpServer;
// Tracks which handler is due to handle the next I/O completion event
HttpSysRequestHandler* m_CompletionHandler = nullptr;
@@ -418,7 +448,10 @@ public:
virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
void SuppressResponseBody(); // typically used for HEAD requests
- inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
+ inline uint16_t GetResponseCode() const { return m_ResponseCode; }
+ inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
+
+ void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; }
private:
eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks;
@@ -429,6 +462,7 @@ private:
bool m_IsInitialResponse = true;
HttpContentType m_ContentType = HttpContentType::kBinary;
eastl::fixed_vector<IoBuffer, 16> m_DataBuffers;
+ std::string m_LocationHeader;
void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs);
};
@@ -569,7 +603,7 @@ HttpMessageResponseRequest::SuppressResponseBody()
HttpSysRequestHandler*
HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
- ZEN_UNUSED(NumberOfBytesTransferred);
+ Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
if (IoResult != NO_ERROR)
{
@@ -684,6 +718,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
ContentTypeHeader->pRawValue = ContentTypeString.data();
ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
+ // Location header (for redirects)
+
+ if (!m_LocationHeader.empty())
+ {
+ PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation];
+ LocationHeader->pRawValue = m_LocationHeader.data();
+ LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size();
+ }
+
std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
HttpResponse.StatusCode = m_ResponseCode;
@@ -694,21 +737,22 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
HTTP_CACHE_POLICY CachePolicy;
- CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates;
+ CachePolicy.Policy = HttpCachePolicyNocache;
CachePolicy.SecondsToLive = 0;
// Initial response API call
- SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(),
- HttpReq->RequestId,
- SendFlags,
- &HttpResponse,
- &CachePolicy,
- NULL,
- NULL,
- 0,
- Tx.Overlapped(),
- NULL);
+ SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), // RequestQueueHandle
+ HttpReq->RequestId, // RequestId
+ SendFlags, // Flags
+ &HttpResponse, // HttpResponse
+ &CachePolicy, // CachePolicy
+ NULL, // BytesSent
+ NULL, // Reserved1
+ 0, // Reserved2
+ Tx.Overlapped(), // Overlapped
+ NULL // LogData
+ );
m_IsInitialResponse = false;
}
@@ -716,9 +760,9 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
{
// Subsequent response API calls
- SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(),
- HttpReq->RequestId,
- SendFlags,
+ SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), // RequestQueueHandle
+ HttpReq->RequestId, // RequestId
+ SendFlags, // Flags
(USHORT)ThisRequestChunkCount, // EntityChunkCount
&m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks
NULL, // BytesSent
@@ -884,7 +928,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr
ZEN_UNUSED(IoResult, NumberOfBytesTransferred);
- ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred);
+ ZEN_WARN("Unexpected I/O completion during async work! IoResult: {} ({:#x}), NumberOfBytesTransferred: {}",
+ GetSystemErrorAsString(IoResult),
+ IoResult,
+ NumberOfBytesTransferred);
return this;
}
@@ -1017,8 +1064,10 @@ HttpSysServer::~HttpSysServer()
ZEN_ERROR("~HttpSysServer() called without calling Close() first");
}
- delete m_AsyncWorkPool;
+ auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed);
m_AsyncWorkPool = nullptr;
+
+ delete WorkPool;
}
void
@@ -1049,7 +1098,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})",
+ WideToUtf8(WildcardUrlPath),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1058,7 +1110,7 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result);
return 0;
}
@@ -1082,7 +1134,9 @@ HttpSysServer::InitializeServer(int BasePort)
if ((Result == ERROR_SHARING_VIOLATION))
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
@@ -1104,7 +1158,9 @@ HttpSysServer::InitializeServer(int BasePort)
{
for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++)
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
}
@@ -1128,25 +1184,29 @@ HttpSysServer::InitializeServer(int BasePort)
// port for the current user. eg:
// netsh http add urlacl url=http://*:8558/ user=<some_user>
- ZEN_WARN(
- "Unable to register handler using '{}' - falling back to local-only. "
- "Please ensure the appropriate netsh URL reservation configuration "
- "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)",
- WideToUtf8(WildcardUrlPath));
+ if (!m_InitialConfig.ForceLoopback)
+ {
+ ZEN_WARN(
+ "Unable to register handler using '{}' - falling back to local-only. "
+ "Please ensure the appropriate netsh URL reservation configuration "
+ "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)",
+ WideToUtf8(WildcardUrlPath));
+ }
const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
- ULONG InternalResult = ERROR_SHARING_VIOLATION;
- for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
+ bool ShouldRetryNextPort = true;
+ for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset)
{
- EffectivePort = BasePort + (PortOffset * 100);
+ EffectivePort = BasePort + (PortOffset * 100);
+ ShouldRetryNextPort = false;
for (const std::u8string_view Host : Hosts)
{
WideStringBuilder<64> LocalUrlPath;
LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv;
- InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+ ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
if (InternalResult == NO_ERROR)
{
@@ -1154,11 +1214,25 @@ HttpSysServer::InitializeServer(int BasePort)
m_BaseUris.push_back(LocalUrlPath.c_str());
}
+ else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED)
+ {
+ // Port may be owned by another process's wildcard registration (access denied)
+ // or actively in use (sharing violation) — retry on a different port
+ ShouldRetryNextPort = true;
+ }
else
{
- break;
+ ZEN_WARN("Failed to register local handler '{}': {} ({:#x})",
+ WideToUtf8(LocalUrlPath),
+ GetSystemErrorAsString(InternalResult),
+ InternalResult);
}
}
+
+ if (!m_BaseUris.empty())
+ {
+ break;
+ }
}
}
else
@@ -1174,7 +1248,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (m_BaseUris.empty())
{
- ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})",
+ WideToUtf8(WildcardUrlPath),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1192,7 +1269,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1204,7 +1284,10 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
@@ -1236,7 +1319,7 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result);
+ ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result);
}
}
@@ -1258,21 +1341,6 @@ HttpSysServer::InitializeServer(int BasePort)
ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
}
- // This is not available in all Windows SDK versions so for now we can't use recently
- // released functionality. We should investigate how to get more recent SDK releases
- // into the build
-
-# if 0
- if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4))
- {
- ZEN_DEBUG("HTTP3 is available");
- }
- else
- {
- ZEN_DEBUG("HTTP3 is NOT available");
- }
-# endif
-
return EffectivePort;
}
@@ -1305,17 +1373,17 @@ HttpSysServer::WorkPool()
{
ZEN_MEMSCOPE(GetHttpsysTag());
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_acquire))
{
RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock);
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_relaxed))
{
- m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async");
+ m_AsyncWorkPool.store(new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"), std::memory_order_release);
}
}
- return *m_AsyncWorkPool;
+ return *m_AsyncWorkPool.load(std::memory_order_relaxed);
}
void
@@ -1337,9 +1405,9 @@ HttpSysServer::OnRun(bool IsInteractive)
ZEN_CONSOLE("Zen Server running (http.sys). Press ESC or Q to quit");
}
+ bool ShutdownRequested = false;
do
{
- // int WaitTimeout = -1;
int WaitTimeout = 100;
if (IsInteractive)
@@ -1352,14 +1420,15 @@ HttpSysServer::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
}
- m_ShutdownEvent.Wait(WaitTimeout);
+ ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout);
UpdateLofreqTimerValue();
- } while (!IsApplicationExitRequested());
+ } while (!ShutdownRequested);
}
void
@@ -1530,7 +1599,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
// than one thread at any given moment. This means we need to be careful about what
// happens in here
- HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped);
+ HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped);
+
+ switch (IoContext->ContextType)
+ {
+ case HttpSysIoContext::Type::kWebSocketRead:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kWebSocketWrite:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kTransaction:
+ break;
+ }
+
+ HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext);
if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone)
{
@@ -1641,6 +1726,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
{
HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload);
+ m_HttpServer.MarkRequest();
+
// Default request handling
# if ZEN_WITH_OTEL
@@ -1666,9 +1753,21 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator);
# endif
- if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
+ {
+ if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ {
+ Service.HandleRequest(ThisRequest);
+ }
+ }
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ ThisRequest.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
{
- Service.HandleRequest(ThisRequest);
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
}
return ThisRequest;
@@ -1810,6 +1909,52 @@ HttpSysServerRequest::ParseRequestId() const
return 0;
}
+bool
+HttpSysServerRequest::IsLocalMachineRequest() const
+{
+ const PSOCKADDR LocalAddress = m_HttpTx.HttpRequest()->Address.pLocalAddress;
+ const PSOCKADDR RemoteAddress = m_HttpTx.HttpRequest()->Address.pRemoteAddress;
+ if (LocalAddress->sa_family != RemoteAddress->sa_family)
+ {
+ return false;
+ }
+ if (LocalAddress->sa_family == AF_INET)
+ {
+ const SOCKADDR_IN& LocalAddressv4 = (const SOCKADDR_IN&)(*LocalAddress);
+ const SOCKADDR_IN& RemoteAddressv4 = (const SOCKADDR_IN&)(*RemoteAddress);
+ return LocalAddressv4.sin_addr.S_un.S_addr == RemoteAddressv4.sin_addr.S_un.S_addr;
+ }
+ else if (LocalAddress->sa_family == AF_INET6)
+ {
+ const SOCKADDR_IN6& LocalAddressv6 = (const SOCKADDR_IN6&)(*LocalAddress);
+ const SOCKADDR_IN6& RemoteAddressv6 = (const SOCKADDR_IN6&)(*RemoteAddress);
+ return memcmp(&LocalAddressv6.sin6_addr, &RemoteAddressv6.sin6_addr, sizeof(in6_addr)) == 0;
+ }
+ else
+ {
+ return false;
+ }
+}
+
+std::string_view
+HttpSysServerRequest::GetRemoteAddress() const
+{
+ if (m_RemoteAddress.Size() == 0)
+ {
+ const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress;
+ GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false);
+ }
+ return m_RemoteAddress.ToView();
+}
+
+std::string_view
+HttpSysServerRequest::GetAuthorizationHeader() const
+{
+ const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
+ const HTTP_KNOWN_HEADER& AuthorizationHeader = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAuthorization];
+ return std::string_view(AuthorizationHeader.pRawValue, AuthorizationHeader.RawValueLength);
+}
+
IoBuffer
HttpSysServerRequest::ReadPayload()
{
@@ -1823,7 +1968,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
ZEN_ASSERT(IsHandled() == false);
- auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
+ HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
if (SuppressBody())
{
@@ -1841,6 +1986,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
# endif
SetIsHandled();
+ LogRequest(Response);
}
void
@@ -1850,7 +1996,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
ZEN_ASSERT(IsHandled() == false);
- auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
+ HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
if (SuppressBody())
{
@@ -1868,6 +2014,20 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
# endif
SetIsHandled();
+ LogRequest(Response);
+}
+
+void
+HttpSysServerRequest::LogRequest(HttpMessageResponseRequest* Response)
+{
+ if (ShouldLogRequest())
+ {
+ ZEN_INFO("{} {} {} -> {}",
+ ToString(RequestVerb()),
+ m_UriUtf8.c_str(),
+ Response->GetResponseCode(),
+ NiceBytes(Response->GetResponseBodySize()));
+ }
}
void
@@ -1896,6 +2056,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
# endif
SetIsHandled();
+ LogRequest(Response);
}
void
@@ -2015,6 +2176,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
break;
}
+ Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
+
ZEN_TRACE_CPU("httpsys::HandleCompletion");
// Route request
@@ -2023,64 +2186,122 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
{
HTTP_REQUEST* HttpReq = HttpRequest();
-# if 0
- for (int i = 0; i < HttpReq->RequestInfoCount; ++i)
+ if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
{
- auto& ReqInfo = HttpReq->pRequestInfo[i];
-
- switch (ReqInfo.InfoType)
+ // WebSocket upgrade detection
+ if (m_IsInitialRequest)
{
- case HttpRequestInfoTypeRequestTiming:
+ const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade];
+ if (UpgradeHeader.RawValueLength > 0 &&
+ StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0)
+ {
+ if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service))
{
- const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo);
+ // Extract Sec-WebSocket-Key from the unknown headers
+ // (http.sys has no known-header slot for it)
+ std::string_view SecWebSocketKey;
+ for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i)
+ {
+ const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i];
+ if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0)
+ {
+ SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength);
+ break;
+ }
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeAuth:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeChannelBind:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslProtocol:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBindingDraft:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBinding:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV0:
- {
- const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo);
+ if (SecWebSocketKey.empty())
+ {
+ ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header");
+ return nullptr;
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeRequestSizing:
- {
- const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo);
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeQuicStats:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV1:
- {
- const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo);
+ const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey);
+
+ HANDLE RequestQueueHandle = Transaction().RequestQueueHandle();
+ HTTP_REQUEST_ID RequestId = HttpReq->RequestId;
+
+ // Build the 101 Switching Protocols response
+ HTTP_RESPONSE Response = {};
+ Response.StatusCode = 101;
+ Response.pReason = "Switching Protocols";
+ Response.ReasonLength = (USHORT)strlen(Response.pReason);
+
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket";
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9;
+
+ eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders;
- ZEN_INFO("");
+ // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders
+ // despite there being an entry for it there (HttpHeaderConnection). If you try to do
+ // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below
+
+ UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"});
+
+ UnknownHeaders.push_back({.NameLength = 20,
+ .RawValueLength = (USHORT)AcceptKey.size(),
+ .pName = "Sec-WebSocket-Accept",
+ .pRawValue = AcceptKey.c_str()});
+
+ Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size();
+ Response.Headers.pUnknownHeaders = UnknownHeaders.data();
+
+ const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
+
+ // Use an OVERLAPPED with an event so we can wait synchronously.
+ // The request queue is IOCP-associated, so passing NULL for pOverlapped
+ // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent
+ // prevents IOCP delivery and lets us wait on the event directly.
+ OVERLAPPED SendOverlapped = {};
+ HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr);
+ SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1);
+
+ ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle,
+ RequestId,
+ Flags,
+ &Response,
+ nullptr, // CachePolicy
+ nullptr, // BytesSent
+ nullptr, // Reserved1
+ 0, // Reserved2
+ &SendOverlapped,
+ nullptr // LogData
+ );
+
+ if (SendResult == ERROR_IO_PENDING)
+ {
+ WaitForSingleObject(SendEvent, INFINITE);
+ SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE;
+ }
+
+ CloseHandle(SendEvent);
+
+ if (SendResult == NO_ERROR)
+ {
+ Transaction().Server().OnWebSocketConnectionOpened();
+ Ref<WsHttpSysConnection> WsConn(new WsHttpSysConnection(RequestQueueHandle,
+ RequestId,
+ *WsHandler,
+ Transaction().Iocp(),
+ &Transaction().Server()));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+
+ return nullptr;
+ }
+
+ ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult);
+
+ // WebSocket upgrade failed — return nullptr since ServerRequest()
+ // was never populated (no InvokeRequestHandler call)
+ return nullptr;
}
- break;
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
}
- }
-# endif
- if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
- {
if (m_IsInitialRequest)
{
m_ContentLength = GetContentLength(HttpReq);
@@ -2146,6 +2367,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
}
}
+ else
+ {
+ // If a default redirect is configured and the request is for the root path, send a 302
+ std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect();
+ std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength);
+ if (!DefaultRedirect.empty() && (RawUrl == "/" || RawUrl.empty()))
+ {
+ auto* Response = new HttpMessageResponseRequest(Transaction(), 302);
+ Response->SetLocationHeader(DefaultRedirect);
+ return Response;
+ }
+ }
// Unable to route
return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv);
@@ -2205,12 +2438,81 @@ HttpSysServer::OnRequestExit()
m_ShutdownEvent.Set();
}
+std::string
+HttpSysServer::OnGetExternalHost() const
+{
+ // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback
+ bool IsPublic = false;
+ for (const auto& Uri : m_BaseUris)
+ {
+ if (Uri.find(L'*') != std::wstring::npos)
+ {
+ IsPublic = true;
+ break;
+ }
+ }
+
+ if (!IsPublic)
+ {
+ return "127.0.0.1";
+ }
+
+ // Use the UDP connect trick: connecting a UDP socket to an external address
+ // causes the OS to select the appropriate local interface without sending any data.
+ try
+ {
+ asio::io_service IoService;
+ asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4());
+ Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80));
+ return Sock.local_endpoint().address().to_string();
+ }
+ catch (const std::exception&)
+ {
+ return GetMachineName();
+ }
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesReceived() const
+{
+ return m_TotalBytesReceived.load(std::memory_order_relaxed);
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesSent() const
+{
+ return m_TotalBytesSent.load(std::memory_order_relaxed);
+}
+
void
HttpSysServer::OnRegisterService(HttpService& Service)
{
RegisterService(Service.BaseUri(), Service);
}
+void
+HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ RwLock::ExclusiveLockScope _(m_RequestFilterLock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+IHttpRequestFilter::Result
+HttpSysServer::FilterRequest(HttpSysServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_RequestFilterLock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ return RequestFilter->FilterRequest(Request);
+}
+
Ref<HttpServer>
CreateHttpSysServer(HttpSysConfig Config)
{
diff --git a/src/zenhttp/servers/httpsys_iocontext.h b/src/zenhttp/servers/httpsys_iocontext.h
new file mode 100644
index 000000000..4f8a97012
--- /dev/null
+++ b/src/zenhttp/servers/httpsys_iocontext.h
@@ -0,0 +1,40 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+
+# include <cstdint>
+
+namespace zen {
+
+/**
+ * Tagged OVERLAPPED wrapper for http.sys IOCP dispatch
+ *
+ * Both HttpSysTransaction (for normal HTTP request I/O) and WsHttpSysConnection
+ * (for WebSocket read/write) embed this struct. The single IoCompletionCallback
+ * bound to the request queue uses the ContextType tag to dispatch to the correct
+ * handler.
+ *
+ * The Overlapped member must be first so that CONTAINING_RECORD works to recover
+ * the HttpSysIoContext from the OVERLAPPED pointer provided by the threadpool.
+ */
+struct HttpSysIoContext
+{
+ OVERLAPPED Overlapped{};
+
+ enum class Type : uint8_t
+ {
+ kTransaction,
+ kWebSocketRead,
+ kWebSocketWrite,
+ } ContextType = Type::kTransaction;
+
+ void* Owner = nullptr;
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/httptracer.h b/src/zenhttp/servers/httptracer.h
index da72c79c9..a9a45f162 100644
--- a/src/zenhttp/servers/httptracer.h
+++ b/src/zenhttp/servers/httptracer.h
@@ -1,9 +1,9 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zenhttp/httpserver.h>
-
#pragma once
+#include <zenhttp/httpserver.h>
+
namespace zen {
/** Helper class for HTTP server implementations
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
new file mode 100644
index 000000000..b2543277a
--- /dev/null
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -0,0 +1,311 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wsasio.h"
+#include "wsframecodec.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/httpserver.h>
+
+namespace zen::asio_http {
+
+static LoggerRef
+WsLog()
+{
+ static LoggerRef g_Logger = logging::Get("ws");
+ return g_Logger;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+WsAsioConnection::WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server)
+: m_Socket(std::move(Socket))
+, m_Handler(Handler)
+, m_HttpServer(Server)
+{
+}
+
+WsAsioConnection::~WsAsioConnection()
+{
+ m_IsOpen.store(false);
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketConnectionClosed();
+ }
+}
+
+void
+WsAsioConnection::Start()
+{
+ EnqueueRead();
+}
+
+bool
+WsAsioConnection::IsOpen() const
+{
+ return m_IsOpen.load(std::memory_order_relaxed);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Read loop
+//
+
+void
+WsAsioConnection::EnqueueRead()
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ Ref<WsAsioConnection> Self(this);
+
+ asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) {
+ Self->OnDataReceived(Ec, ByteCount);
+ });
+}
+
+void
+WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "WebSocket read error: {}", Ec.message());
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection lost");
+ }
+ return;
+ }
+
+ ProcessReceivedData();
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ EnqueueRead();
+ }
+}
+
+void
+WsAsioConnection::ProcessReceivedData()
+{
+ while (m_ReadBuffer.size() > 0)
+ {
+ const auto& InputBuffer = m_ReadBuffer.data();
+ const auto* Data = static_cast<const uint8_t*>(InputBuffer.data());
+ const auto Size = InputBuffer.size();
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size);
+ if (!Frame.IsValid)
+ {
+ break; // not enough data yet
+ }
+
+ m_ReadBuffer.consume(Frame.BytesConsumed);
+
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed);
+ }
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with pong carrying the same payload
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ // Unsolicited pong — ignore per RFC 6455
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo close frame back if we haven't sent one yet
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+
+ // Shut down the socket
+ std::error_code ShutdownEc;
+ m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc);
+ m_Socket->close(ShutdownEc);
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(WsLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Write queue
+//
+
+void
+WsAsioConnection::SendText(std::string_view Text)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsAsioConnection::SendBinary(std::span<const uint8_t> Data)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsAsioConnection::Close(uint16_t Code, std::string_view Reason)
+{
+ DoClose(Code, Reason);
+}
+
+void
+WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason)
+{
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+}
+
+void
+WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame)
+{
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameSent(Frame.size());
+ }
+
+ bool ShouldFlush = false;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_WriteQueue.push_back(std::move(Frame));
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ });
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+}
+
+void
+WsAsioConnection::FlushWriteQueue()
+{
+ std::vector<uint8_t> Frame;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+ Frame = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ });
+
+ if (Frame.empty())
+ {
+ return;
+ }
+
+ Ref<WsAsioConnection> Self(this);
+
+ // Move Frame into a shared_ptr so we can create the buffer and capture ownership
+ // in the same async_write call without evaluation order issues.
+ auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
+
+ asio::async_write(*m_Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); });
+}
+
+void
+WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "WebSocket write error: {}", Ec.message());
+ }
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_IsWriting = false;
+ m_WriteQueue.clear();
+ });
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+ return;
+ }
+
+ FlushWriteQueue();
+}
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h
new file mode 100644
index 000000000..e8bb3b1d2
--- /dev/null
+++ b/src/zenhttp/servers/wsasio.h
@@ -0,0 +1,77 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include <zencore/thread.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <deque>
+#include <memory>
+#include <vector>
+
+namespace zen {
+class HttpServer;
+} // namespace zen
+
+namespace zen::asio_http {
+
+/**
+ * WebSocket connection over an ASIO TCP socket
+ *
+ * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake)
+ * and runs an async read/write loop to exchange WebSocket frames.
+ *
+ * Lifetime is managed solely through intrusive reference counting (RefCounted).
+ * The async read/write callbacks capture Ref<WsAsioConnection> to keep the
+ * connection alive for the duration of the async operation. The service layer
+ * also holds a Ref<WebSocketConnection>.
+ */
+
+class WsAsioConnection : public WebSocketConnection
+{
+public:
+ WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server);
+ ~WsAsioConnection() override;
+
+ /**
+ * Start the async read loop. Must be called once after construction
+ * and the 101 response has been sent.
+ */
+ void Start();
+
+ // WebSocketConnection interface
+ void SendText(std::string_view Text) override;
+ void SendBinary(std::span<const uint8_t> Data) override;
+ void Close(uint16_t Code, std::string_view Reason) override;
+ bool IsOpen() const override;
+
+private:
+ void EnqueueRead();
+ void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
+ void ProcessReceivedData();
+
+ void EnqueueWrite(std::vector<uint8_t> Frame);
+ void FlushWriteQueue();
+ void OnWriteComplete(const asio::error_code& Ec, std::size_t ByteCount);
+
+ void DoClose(uint16_t Code, std::string_view Reason);
+
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ IWebSocketHandler& m_Handler;
+ zen::HttpServer* m_HttpServer;
+ asio::streambuf m_ReadBuffer;
+
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ bool m_IsWriting = false;
+
+ std::atomic<bool> m_IsOpen{true};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp
new file mode 100644
index 000000000..e452141fe
--- /dev/null
+++ b/src/zenhttp/servers/wsframecodec.cpp
@@ -0,0 +1,236 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wsframecodec.h"
+
+#include <zencore/base64.h>
+#include <zencore/sha1.h>
+
+#include <cstring>
+#include <random>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame parsing
+//
+
+WsFrameParseResult
+WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size)
+{
+ // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames)
+ if (Size < 2)
+ {
+ return {};
+ }
+
+ const bool Fin = (Data[0] & 0x80) != 0;
+ const uint8_t OpcodeRaw = Data[0] & 0x0F;
+ const bool Masked = (Data[1] & 0x80) != 0;
+ uint64_t PayloadLen = Data[1] & 0x7F;
+
+ size_t HeaderSize = 2;
+
+ if (PayloadLen == 126)
+ {
+ if (Size < 4)
+ {
+ return {};
+ }
+ PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]);
+ HeaderSize = 4;
+ }
+ else if (PayloadLen == 127)
+ {
+ if (Size < 10)
+ {
+ return {};
+ }
+ PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) |
+ (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]);
+ HeaderSize = 10;
+ }
+
+ // Reject frames with unreasonable payload sizes to prevent OOM
+ static constexpr uint64_t kMaxPayloadSize = 256 * 1024 * 1024; // 256 MB
+ if (PayloadLen > kMaxPayloadSize)
+ {
+ return {};
+ }
+
+ const size_t MaskSize = Masked ? 4 : 0;
+ const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen;
+
+ if (Size < TotalFrame)
+ {
+ return {};
+ }
+
+ const uint8_t* MaskKey = Masked ? (Data + HeaderSize) : nullptr;
+ const uint8_t* PayloadData = Data + HeaderSize + MaskSize;
+
+ WsFrameParseResult Result;
+ Result.IsValid = true;
+ Result.BytesConsumed = TotalFrame;
+ Result.Opcode = static_cast<WebSocketOpcode>(OpcodeRaw);
+ Result.Fin = Fin;
+
+ Result.Payload.resize(static_cast<size_t>(PayloadLen));
+ if (PayloadLen > 0)
+ {
+ std::memcpy(Result.Payload.data(), PayloadData, static_cast<size_t>(PayloadLen));
+
+ if (Masked)
+ {
+ for (size_t i = 0; i < Result.Payload.size(); ++i)
+ {
+ Result.Payload[i] ^= MaskKey[i & 3];
+ }
+ }
+ }
+
+ return Result;
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame building (server-to-client, no masking)
+//
+
+std::vector<uint8_t>
+WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+{
+ std::vector<uint8_t> Frame;
+
+ const size_t PayloadLen = Payload.size();
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length (no mask bit for server frames)
+ if (PayloadLen < 126)
+ {
+ Frame.push_back(static_cast<uint8_t>(PayloadLen));
+ }
+ else if (PayloadLen <= 0xFFFF)
+ {
+ Frame.push_back(126);
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF));
+ }
+ }
+
+ Frame.insert(Frame.end(), Payload.begin(), Payload.end());
+
+ return Frame;
+}
+
+std::vector<uint8_t>
+WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason)
+{
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ Payload.insert(Payload.end(), Reason.begin(), Reason.end());
+
+ return BuildFrame(WebSocketOpcode::kClose, Payload);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame building (client-to-server, with masking)
+//
+
+std::vector<uint8_t>
+WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+{
+ std::vector<uint8_t> Frame;
+
+ const size_t PayloadLen = Payload.size();
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length with mask bit set
+ if (PayloadLen < 126)
+ {
+ Frame.push_back(0x80 | static_cast<uint8_t>(PayloadLen));
+ }
+ else if (PayloadLen <= 0xFFFF)
+ {
+ Frame.push_back(0x80 | 126);
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(0x80 | 127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF));
+ }
+ }
+
+ // Generate random 4-byte mask key
+ static thread_local std::mt19937 s_Rng(std::random_device{}());
+ uint32_t MaskValue = s_Rng();
+ uint8_t MaskKey[4];
+ std::memcpy(MaskKey, &MaskValue, 4);
+
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+
+ // Masked payload
+ for (size_t i = 0; i < PayloadLen; ++i)
+ {
+ Frame.push_back(Payload[i] ^ MaskKey[i & 3]);
+ }
+
+ return Frame;
+}
+
+std::vector<uint8_t>
+WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason)
+{
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ Payload.insert(Payload.end(), Reason.begin(), Reason.end());
+
+ return BuildMaskedFrame(WebSocketOpcode::kClose, Payload);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2)
+//
+
+static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+
+std::string
+WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey)
+{
+ // Concatenate client key with the magic GUID
+ std::string Combined;
+ Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size());
+ Combined.append(ClientKey);
+ Combined.append(kWebSocketMagicGuid);
+
+ // SHA1 hash
+ SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size());
+
+ // Base64 encode the 20-byte hash
+ char Base64Buf[Base64::GetEncodedDataSize(20) + 1];
+ uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf);
+ Base64Buf[EncodedLen] = '\0';
+
+ return std::string(Base64Buf, EncodedLen);
+}
+
+} // namespace zen
diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h
new file mode 100644
index 000000000..2d90b6fa1
--- /dev/null
+++ b/src/zenhttp/servers/wsframecodec.h
@@ -0,0 +1,74 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <span>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace zen {
+
+/**
+ * Result of attempting to parse a single WebSocket frame from a byte buffer
+ */
+struct WsFrameParseResult
+{
+ bool IsValid = false; // true if a complete frame was successfully parsed
+ size_t BytesConsumed = 0; // number of bytes consumed from the input buffer
+ WebSocketOpcode Opcode = WebSocketOpcode::kText;
+ bool Fin = false;
+ std::vector<uint8_t> Payload;
+};
+
+/**
+ * RFC 6455 WebSocket frame codec
+ *
+ * Provides static helpers for parsing client-to-server frames (which are
+ * always masked) and building server-to-client frames (which are never masked).
+ */
+struct WsFrameCodec
+{
+ /**
+ * Try to parse one complete frame from the front of the buffer.
+ *
+ * Returns a result with IsValid == false and BytesConsumed == 0 when
+ * there is not enough data yet. The caller should accumulate more data
+ * and retry.
+ */
+ static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size);
+
+ /**
+ * Build a server-to-client frame (no masking)
+ */
+ static std::vector<uint8_t> BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload);
+
+ /**
+ * Build a close frame with a status code and optional reason string
+ */
+ static std::vector<uint8_t> BuildCloseFrame(uint16_t Code, std::string_view Reason = {});
+
+ /**
+ * Build a client-to-server frame (with masking per RFC 6455)
+ */
+ static std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload);
+
+ /**
+ * Build a masked close frame with status code and optional reason
+ */
+ static std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason = {});
+
+ /**
+ * Compute the Sec-WebSocket-Accept value per RFC 6455 section 4.2.2
+ *
+ * accept = Base64(SHA1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
+ */
+ static std::string ComputeAcceptKey(std::string_view ClientKey);
+};
+
+} // namespace zen
diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp
new file mode 100644
index 000000000..af320172d
--- /dev/null
+++ b/src/zenhttp/servers/wshttpsys.cpp
@@ -0,0 +1,485 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wshttpsys.h"
+
+#if ZEN_WITH_HTTPSYS
+
+# include "wsframecodec.h"
+
+# include <zencore/logging.h>
+# include <zenhttp/httpserver.h>
+
+namespace zen {
+
+static LoggerRef
+WsHttpSysLog()
+{
+ static LoggerRef g_Logger = logging::Get("ws_httpsys");
+ return g_Logger;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle,
+ HTTP_REQUEST_ID RequestId,
+ IWebSocketHandler& Handler,
+ PTP_IO Iocp,
+ HttpServer* Server)
+: m_RequestQueueHandle(RequestQueueHandle)
+, m_RequestId(RequestId)
+, m_Handler(Handler)
+, m_Iocp(Iocp)
+, m_HttpServer(Server)
+, m_ReadBuffer(8192)
+{
+ m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead;
+ m_ReadIoContext.Owner = this;
+ m_WriteIoContext.ContextType = HttpSysIoContext::Type::kWebSocketWrite;
+ m_WriteIoContext.Owner = this;
+}
+
+WsHttpSysConnection::~WsHttpSysConnection()
+{
+ ZEN_ASSERT(m_OutstandingOps.load() == 0);
+
+ if (m_IsOpen.exchange(false))
+ {
+ Disconnect();
+ }
+
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketConnectionClosed();
+ }
+}
+
+void
+WsHttpSysConnection::Start()
+{
+ m_SelfRef = Ref<WsHttpSysConnection>(this);
+ IssueAsyncRead();
+}
+
+void
+WsHttpSysConnection::Shutdown()
+{
+ m_ShutdownRequested.store(true, std::memory_order_relaxed);
+
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED
+ HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
+}
+
+bool
+WsHttpSysConnection::IsOpen() const
+{
+ return m_IsOpen.load(std::memory_order_relaxed);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Async read path
+//
+
+void
+WsHttpSysConnection::IssueAsyncRead()
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed) || m_ShutdownRequested.load(std::memory_order_relaxed))
+ {
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ m_OutstandingOps.fetch_add(1, std::memory_order_relaxed);
+
+ ZeroMemory(&m_ReadIoContext.Overlapped, sizeof(OVERLAPPED));
+
+ StartThreadpoolIo(m_Iocp);
+
+ ULONG Result = HttpReceiveRequestEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ 0, // Flags
+ m_ReadBuffer.data(),
+ (ULONG)m_ReadBuffer.size(),
+ nullptr, // BytesRead (ignored for async)
+ &m_ReadIoContext.Overlapped);
+
+ if (Result != NO_ERROR && Result != ERROR_IO_PENDING)
+ {
+ CancelThreadpoolIo(m_Iocp);
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "read issue failed");
+ }
+
+ MaybeReleaseSelfRef();
+ }
+}
+
+void
+WsHttpSysConnection::OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ // Hold a transient ref to prevent mid-callback destruction after MaybeReleaseSelfRef
+ Ref<WsHttpSysConnection> Guard(this);
+
+ if (IoResult != NO_ERROR)
+ {
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.exchange(false))
+ {
+ if (IoResult == ERROR_HANDLE_EOF)
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection closed");
+ }
+ else if (IoResult != ERROR_OPERATION_ABORTED)
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection lost");
+ }
+ }
+
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ if (NumberOfBytesTransferred > 0)
+ {
+ m_Accumulated.insert(m_Accumulated.end(), m_ReadBuffer.begin(), m_ReadBuffer.begin() + NumberOfBytesTransferred);
+ ProcessReceivedData();
+ }
+
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ IssueAsyncRead();
+ }
+ else
+ {
+ MaybeReleaseSelfRef();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame parsing
+//
+
+void
+WsHttpSysConnection::ProcessReceivedData()
+{
+ while (!m_Accumulated.empty())
+ {
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size());
+ if (!Frame.IsValid)
+ {
+ break; // not enough data yet
+ }
+
+ // Remove consumed bytes
+ m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed);
+
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed);
+ }
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with pong carrying the same payload
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ // Unsolicited pong — ignore per RFC 6455
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo close frame back if we haven't sent one yet
+ {
+ bool ShouldSendClose = false;
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ if (!m_CloseSent.exchange(true))
+ {
+ ShouldSendClose = true;
+ }
+ }
+ if (ShouldSendClose)
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+ Disconnect();
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(WsHttpSysLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Async write path
+//
+
+void
+WsHttpSysConnection::EnqueueWrite(std::vector<uint8_t> Frame)
+{
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameSent(Frame.size());
+ }
+
+ bool ShouldFlush = false;
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.push_back(std::move(Frame));
+
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ }
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+}
+
+void
+WsHttpSysConnection::FlushWriteQueue()
+{
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+
+ m_CurrentWriteBuffer = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ }
+
+ m_OutstandingOps.fetch_add(1, std::memory_order_relaxed);
+
+ ZeroMemory(&m_WriteChunk, sizeof(m_WriteChunk));
+ m_WriteChunk.DataChunkType = HttpDataChunkFromMemory;
+ m_WriteChunk.FromMemory.pBuffer = m_CurrentWriteBuffer.data();
+ m_WriteChunk.FromMemory.BufferLength = (ULONG)m_CurrentWriteBuffer.size();
+
+ ZeroMemory(&m_WriteIoContext.Overlapped, sizeof(OVERLAPPED));
+
+ StartThreadpoolIo(m_Iocp);
+
+ ULONG Result = HttpSendResponseEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ HTTP_SEND_RESPONSE_FLAG_MORE_DATA,
+ 1,
+ &m_WriteChunk,
+ nullptr,
+ nullptr,
+ 0,
+ &m_WriteIoContext.Overlapped,
+ nullptr);
+
+ if (Result != NO_ERROR && Result != ERROR_IO_PENDING)
+ {
+ CancelThreadpoolIo(m_Iocp);
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket async write failed: {}", Result);
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.clear();
+ m_IsWriting = false;
+ }
+ m_CurrentWriteBuffer.clear();
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+
+ MaybeReleaseSelfRef();
+ }
+}
+
+void
+WsHttpSysConnection::OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ ZEN_UNUSED(NumberOfBytesTransferred);
+
+ // Hold a transient ref to prevent mid-callback destruction
+ Ref<WsHttpSysConnection> Guard(this);
+
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+ m_CurrentWriteBuffer.clear();
+
+ if (IoResult != NO_ERROR)
+ {
+ ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket write completion error: {}", IoResult);
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.clear();
+ m_IsWriting = false;
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ FlushWriteQueue();
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Send interface
+//
+
+void
+WsHttpSysConnection::SendText(std::string_view Text)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsHttpSysConnection::SendBinary(std::span<const uint8_t> Data)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsHttpSysConnection::Close(uint16_t Code, std::string_view Reason)
+{
+ DoClose(Code, Reason);
+}
+
+void
+WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason)
+{
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ {
+ bool ShouldSendClose = false;
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ if (!m_CloseSent.exchange(true))
+ {
+ ShouldSendClose = true;
+ }
+ }
+ if (ShouldSendClose)
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+
+ // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED
+ HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Lifetime management
+//
+
+void
+WsHttpSysConnection::MaybeReleaseSelfRef()
+{
+ if (m_OutstandingOps.load(std::memory_order_relaxed) == 0 && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ m_SelfRef = nullptr;
+ }
+}
+
+void
+WsHttpSysConnection::Disconnect()
+{
+ // Send final empty body with DISCONNECT to tell http.sys the connection is done
+ HttpSendResponseEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ HTTP_SEND_RESPONSE_FLAG_DISCONNECT,
+ 0,
+ nullptr,
+ nullptr,
+ nullptr,
+ 0,
+ nullptr,
+ nullptr);
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h
new file mode 100644
index 000000000..6015e3873
--- /dev/null
+++ b/src/zenhttp/servers/wshttpsys.h
@@ -0,0 +1,107 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include "httpsys_iocontext.h"
+
+#include <zencore/thread.h>
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+# include <http.h>
+
+# include <atomic>
+# include <deque>
+# include <vector>
+
+namespace zen {
+
+class HttpServer;
+
+/**
+ * WebSocket connection over an http.sys opaque-mode connection
+ *
+ * After the 101 Switching Protocols response is sent with
+ * HTTP_SEND_RESPONSE_FLAG_OPAQUE, http.sys stops parsing HTTP on the
+ * connection. Raw bytes are exchanged via HttpReceiveRequestEntityBody /
+ * HttpSendResponseEntityBody using the original RequestId.
+ *
+ * All I/O is performed asynchronously via the same IOCP threadpool used
+ * for normal http.sys traffic, eliminating per-connection threads.
+ *
+ * Lifetime is managed through intrusive reference counting (RefCounted).
+ * A self-reference (m_SelfRef) is held from Start() until all outstanding
+ * async operations have drained, preventing premature destruction.
+ */
+class WsHttpSysConnection : public WebSocketConnection
+{
+public:
+ WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp, HttpServer* Server);
+ ~WsHttpSysConnection() override;
+
+ /**
+ * Start the async read loop. Must be called once after construction
+ * and after the 101 response has been sent.
+ */
+ void Start();
+
+ /**
+ * Shut down the connection. Cancels pending I/O; IOCP completions
+ * will fire with ERROR_OPERATION_ABORTED and drain naturally.
+ */
+ void Shutdown();
+
+ // WebSocketConnection interface
+ void SendText(std::string_view Text) override;
+ void SendBinary(std::span<const uint8_t> Data) override;
+ void Close(uint16_t Code, std::string_view Reason) override;
+ bool IsOpen() const override;
+
+ // Called from IoCompletionCallback via tagged dispatch
+ void OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+ void OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+
+private:
+ void IssueAsyncRead();
+ void ProcessReceivedData();
+ void EnqueueWrite(std::vector<uint8_t> Frame);
+ void FlushWriteQueue();
+ void DoClose(uint16_t Code, std::string_view Reason);
+ void Disconnect();
+ void MaybeReleaseSelfRef();
+
+ HANDLE m_RequestQueueHandle;
+ HTTP_REQUEST_ID m_RequestId;
+ IWebSocketHandler& m_Handler;
+ PTP_IO m_Iocp;
+ HttpServer* m_HttpServer;
+
+ // Tagged OVERLAPPED contexts for concurrent read and write
+ HttpSysIoContext m_ReadIoContext{};
+ HttpSysIoContext m_WriteIoContext{};
+
+ // Read state
+ std::vector<uint8_t> m_ReadBuffer;
+ std::vector<uint8_t> m_Accumulated;
+
+ // Write state
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ std::vector<uint8_t> m_CurrentWriteBuffer;
+ HTTP_DATA_CHUNK m_WriteChunk{};
+ bool m_IsWriting = false;
+
+ // Lifetime management
+ std::atomic<int32_t> m_OutstandingOps{0};
+ Ref<WsHttpSysConnection> m_SelfRef;
+ std::atomic<bool> m_ShutdownRequested{false};
+ std::atomic<bool> m_IsOpen{true};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
new file mode 100644
index 000000000..2134e4ff1
--- /dev/null
+++ b/src/zenhttp/servers/wstest.cpp
@@ -0,0 +1,925 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/scopeguard.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+# include <zenhttp/httpserver.h>
+# include <zenhttp/httpwsclient.h>
+# include <zenhttp/websocket.h>
+
+# include "httpasio.h"
+# include "wsframecodec.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# if ZEN_PLATFORM_WINDOWS
+# include <winsock2.h>
+# else
+# include <poll.h>
+# include <sys/socket.h>
+# endif
+# include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+# include <atomic>
+# include <chrono>
+# include <cstring>
+# include <random>
+# include <string>
+# include <string_view>
+# include <thread>
+# include <vector>
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Unit tests: WsFrameCodec
+//
+
+TEST_SUITE_BEGIN("http.wstest");
+
+TEST_CASE("websocket.framecodec")
+{
+ SUBCASE("ComputeAcceptKey RFC 6455 test vector")
+ {
+ // RFC 6455 section 4.2.2 example
+ std::string AcceptKey = WsFrameCodec::ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
+ CHECK_EQ(AcceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
+ }
+
+ SUBCASE("BuildFrame and TryParseFrame roundtrip - text")
+ {
+ std::string_view Text = "Hello, WebSocket!";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+
+ // Server frames are unmasked — TryParseFrame should handle them
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, Frame.size());
+ CHECK(Result.Fin);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), Text.size());
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text);
+ }
+
+ SUBCASE("BuildFrame and TryParseFrame roundtrip - binary")
+ {
+ std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, BinaryData);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
+ CHECK_EQ(Result.Payload, BinaryData);
+ }
+
+ SUBCASE("BuildFrame - medium payload (126-65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(300, 0x42);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 300u);
+ CHECK_EQ(Result.Payload, Payload);
+ }
+
+ SUBCASE("BuildFrame - large payload (>65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(70000, 0xAB);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 70000u);
+ }
+
+ SUBCASE("BuildCloseFrame roundtrip")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildCloseFrame(1000, "normal closure");
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Result.Payload.size() >= 2);
+
+ uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]);
+ CHECK_EQ(Code, 1000);
+
+ std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
+ CHECK_EQ(Reason, "normal closure");
+ }
+
+ SUBCASE("TryParseFrame - partial data returns invalid")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
+
+ // Pass only 1 byte — not enough for a frame header
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1);
+ CHECK_FALSE(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, 0u);
+ }
+
+ SUBCASE("TryParseFrame - empty payload")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK(Result.Payload.empty());
+ }
+
+ SUBCASE("TryParseFrame - masked client frame")
+ {
+ // Build a masked frame manually as a client would send
+ // Frame: FIN=1, opcode=text, MASK=1, payload_len=5, mask_key=0x37FA213D, payload="Hello"
+ uint8_t MaskKey[4] = {0x37, 0xFA, 0x21, 0x3D};
+ uint8_t MaskedPayload[5] = {};
+ const char* Original = "Hello";
+ for (int i = 0; i < 5; ++i)
+ {
+ MaskedPayload[i] = static_cast<uint8_t>(Original[i]) ^ MaskKey[i % 4];
+ }
+
+ std::vector<uint8_t> Frame;
+ Frame.push_back(0x81); // FIN + text
+ Frame.push_back(0x85); // MASK + len=5
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+ Frame.insert(Frame.end(), MaskedPayload, MaskedPayload + 5);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), 5u);
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), 5), "Hello"sv);
+ }
+
+ SUBCASE("BuildMaskedFrame roundtrip - text")
+ {
+ std::string_view Text = "Hello, masked WebSocket!";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+
+ // Verify mask bit is set
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, Frame.size());
+ CHECK(Result.Fin);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), Text.size());
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text);
+ }
+
+ SUBCASE("BuildMaskedFrame roundtrip - binary")
+ {
+ std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, BinaryData);
+
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
+ CHECK_EQ(Result.Payload, BinaryData);
+ }
+
+ SUBCASE("BuildMaskedFrame - medium payload (126-65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(300, 0x42);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);
+
+ CHECK((Frame[1] & 0x80) != 0);
+ CHECK_EQ((Frame[1] & 0x7F), 126); // 16-bit extended length
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 300u);
+ CHECK_EQ(Result.Payload, Payload);
+ }
+
+ SUBCASE("BuildMaskedFrame - large payload (>65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(70000, 0xAB);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);
+
+ CHECK((Frame[1] & 0x80) != 0);
+ CHECK_EQ((Frame[1] & 0x7F), 127); // 64-bit extended length
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 70000u);
+ }
+
+ SUBCASE("BuildMaskedCloseFrame roundtrip")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure");
+
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Result.Payload.size() >= 2);
+
+ uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]);
+ CHECK_EQ(Code, 1000);
+
+ std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
+ CHECK_EQ(Reason, "normal closure");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Integration tests: WebSocket over ASIO
+//
+
+namespace {
+
+ /**
+ * Helper: Build a masked client-to-server frame per RFC 6455
+ */
+ std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+ {
+ std::vector<uint8_t> Frame;
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length with mask bit set
+ if (Payload.size() < 126)
+ {
+ Frame.push_back(0x80 | static_cast<uint8_t>(Payload.size()));
+ }
+ else if (Payload.size() <= 0xFFFF)
+ {
+ Frame.push_back(0x80 | 126);
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(Payload.size() & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(0x80 | 127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> (i * 8)) & 0xFF));
+ }
+ }
+
+ // Mask key (use a fixed key for deterministic tests)
+ uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78};
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+
+ // Masked payload
+ for (size_t i = 0; i < Payload.size(); ++i)
+ {
+ Frame.push_back(Payload[i] ^ MaskKey[i & 3]);
+ }
+
+ return Frame;
+ }
+
+ std::vector<uint8_t> BuildMaskedTextFrame(std::string_view Text)
+ {
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ return BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ }
+
+ std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code)
+ {
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ return BuildMaskedFrame(WebSocketOpcode::kClose, Payload);
+ }
+
+ /**
+ * Test service that implements IWebSocketHandler
+ */
+ struct WsTestService : public HttpService, public IWebSocketHandler
+ {
+ const char* BaseUri() const override { return "/wstest/"; }
+
+ void HandleRequest(HttpServerRequest& Request) override
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello from wstest");
+ }
+
+ // IWebSocketHandler
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ {
+ m_OpenCount.fetch_add(1);
+
+ m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); });
+ }
+
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override
+ {
+ m_MessageCount.fetch_add(1);
+
+ if (Msg.Opcode == WebSocketOpcode::kText)
+ {
+ std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
+ m_LastMessage = std::string(Text);
+
+ // Echo the message back
+ Conn.SendText(Text);
+ }
+ }
+
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) override
+ {
+ m_CloseCount.fetch_add(1);
+ m_LastCloseCode = Code;
+
+ m_ConnectionsLock.WithExclusiveLock([&] {
+ auto It = std::remove_if(m_Connections.begin(), m_Connections.end(), [&Conn](const Ref<WebSocketConnection>& C) {
+ return C.Get() == &Conn;
+ });
+ m_Connections.erase(It, m_Connections.end());
+ });
+ }
+
+ void SendToAll(std::string_view Text)
+ {
+ RwLock::SharedLockScope _(m_ConnectionsLock);
+ for (auto& Conn : m_Connections)
+ {
+ if (Conn->IsOpen())
+ {
+ Conn->SendText(Text);
+ }
+ }
+ }
+
+ std::atomic<int> m_OpenCount{0};
+ std::atomic<int> m_MessageCount{0};
+ std::atomic<int> m_CloseCount{0};
+ std::atomic<uint16_t> m_LastCloseCode{0};
+ std::string m_LastMessage;
+
+ RwLock m_ConnectionsLock;
+ std::vector<Ref<WebSocketConnection>> m_Connections;
+ };
+
+ /**
+ * Helper: Perform the WebSocket upgrade handshake on a raw TCP socket
+ *
+ * Returns true on success (101 response), false otherwise.
+ */
+ bool DoWebSocketHandshake(asio::ip::tcp::socket& Sock, std::string_view Path, int Port)
+ {
+ // Send HTTP upgrade request
+ ExtendableStringBuilder<512> Request;
+ Request << "GET " << Path << " HTTP/1.1\r\n"
+ << "Host: 127.0.0.1:" << Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ << "Sec-WebSocket-Version: 13\r\n"
+ << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+
+ asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size()));
+
+ // Read the response (look for "101")
+ asio::streambuf ResponseBuf;
+ asio::read_until(Sock, ResponseBuf, "\r\n\r\n");
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+
+ return Response.find("101") != std::string::npos;
+ }
+
+ /**
+ * Helper: Read a single server-to-client frame from a socket
+ *
+ * Uses a background thread with a synchronous ASIO read and a timeout.
+ */
+ WsFrameParseResult ReadOneFrame(asio::ip::tcp::socket& Sock, int TimeoutMs = 5000)
+ {
+ std::vector<uint8_t> Buffer;
+ WsFrameParseResult Result;
+ std::atomic<bool> Done{false};
+
+ std::thread Reader([&] {
+ while (!Done.load())
+ {
+ uint8_t Tmp[4096];
+ asio::error_code Ec;
+ size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec);
+ if (Ec || BytesRead == 0)
+ {
+ break;
+ }
+
+ Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead);
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size());
+ if (Frame.IsValid)
+ {
+ Result = std::move(Frame);
+ Done.store(true);
+ return;
+ }
+ }
+ });
+
+ auto Deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(TimeoutMs);
+ while (!Done.load() && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ if (!Done.load())
+ {
+ // Timeout — cancel the read
+ asio::error_code Ec;
+ Sock.cancel(Ec);
+ }
+
+ if (Reader.joinable())
+ {
+ Reader.join();
+ }
+
+ return Result;
+ }
+
+} // anonymous namespace
+
+TEST_CASE("websocket.integration")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = Server->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ // Give server a moment to start accepting
+ Sleep(100);
+
+ SUBCASE("handshake succeeds with 101")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ CHECK(Ok);
+
+ Sleep(50);
+ CHECK_EQ(TestService.m_OpenCount.load(), 1);
+
+ Sock.close();
+ }
+
+ SUBCASE("normal HTTP still works alongside WebSocket service")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ // Send a normal HTTP GET (not upgrade)
+ std::string HttpReq = fmt::format(
+ "GET /wstest/hello HTTP/1.1\r\n"
+ "Host: 127.0.0.1:{}\r\n"
+ "Connection: close\r\n"
+ "\r\n",
+ Port);
+
+ asio::write(Sock, asio::buffer(HttpReq));
+
+ asio::streambuf ResponseBuf;
+ asio::error_code Ec;
+ asio::read(Sock, ResponseBuf, asio::transfer_at_least(1), Ec);
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+ CHECK(Response.find("200") != std::string::npos);
+ }
+
+ SUBCASE("echo message roundtrip")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send a text message (masked, as client)
+ std::vector<uint8_t> Frame = BuildMaskedTextFrame("ping test");
+ asio::write(Sock, asio::buffer(Frame));
+
+ // Read the echo reply
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
+ std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(ReplyText, "ping test"sv);
+ CHECK_EQ(TestService.m_MessageCount.load(), 1);
+ CHECK_EQ(TestService.m_LastMessage, "ping test");
+
+ Sock.close();
+ }
+
+ SUBCASE("server push to client")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Server pushes a message
+ TestService.SendToAll("server says hello");
+
+ WsFrameParseResult Frame = ReadOneFrame(Sock);
+ REQUIRE(Frame.IsValid);
+ CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
+ std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size());
+ CHECK_EQ(Text, "server says hello"sv);
+
+ Sock.close();
+ }
+
+ SUBCASE("client close handshake")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send close frame
+ std::vector<uint8_t> CloseFrame = BuildMaskedCloseFrame(1000);
+ asio::write(Sock, asio::buffer(CloseFrame));
+
+ // Server should echo close back
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose);
+
+ Sleep(50);
+ CHECK_EQ(TestService.m_CloseCount.load(), 1);
+ CHECK_EQ(TestService.m_LastCloseCode.load(), 1000);
+
+ Sock.close();
+ }
+
+ SUBCASE("multiple concurrent connections")
+ {
+ constexpr int NumClients = 5;
+
+ asio::io_context IoCtx;
+ std::vector<asio::ip::tcp::socket> Sockets;
+
+ for (int i = 0; i < NumClients; ++i)
+ {
+ Sockets.emplace_back(IoCtx);
+ Sockets.back().connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sockets.back(), "/wstest/ws", Port);
+ REQUIRE(Ok);
+ }
+
+ Sleep(100);
+ CHECK_EQ(TestService.m_OpenCount.load(), NumClients);
+
+ // Broadcast from server
+ TestService.SendToAll("broadcast");
+
+ // Each client should receive the message
+ for (int i = 0; i < NumClients; ++i)
+ {
+ WsFrameParseResult Frame = ReadOneFrame(Sockets[i]);
+ REQUIRE(Frame.IsValid);
+ CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
+ std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size());
+ CHECK_EQ(Text, "broadcast"sv);
+ }
+
+ // Close all
+ for (auto& S : Sockets)
+ {
+ S.close();
+ }
+ }
+
+ SUBCASE("service without IWebSocketHandler rejects upgrade")
+ {
+ // Register a plain HTTP service (no WebSocket)
+ struct PlainService : public HttpService
+ {
+ const char* BaseUri() const override { return "/plain/"; }
+ void HandleRequest(HttpServerRequest& Request) override { Request.WriteResponse(HttpResponseCode::OK); }
+ };
+
+ PlainService Plain;
+ Server->RegisterService(Plain);
+
+ Sleep(50);
+
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ // Attempt WebSocket upgrade on the plain service
+ ExtendableStringBuilder<512> Request;
+ Request << "GET /plain/ws HTTP/1.1\r\n"
+ << "Host: 127.0.0.1:" << Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ << "Sec-WebSocket-Version: 13\r\n"
+ << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+ asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size()));
+
+ asio::streambuf ResponseBuf;
+ asio::read_until(Sock, ResponseBuf, "\r\n\r\n");
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+
+ // Should NOT get 101 — should fall through to normal request handling
+ CHECK(Response.find("101") == std::string::npos);
+
+ Sock.close();
+ }
+
+ SUBCASE("ping/pong auto-response")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send a ping frame with payload "test"
+ std::string_view PingPayload = "test";
+ std::span<const uint8_t> PingData(reinterpret_cast<const uint8_t*>(PingPayload.data()), PingPayload.size());
+ std::vector<uint8_t> PingFrame = BuildMaskedFrame(WebSocketOpcode::kPing, PingData);
+ asio::write(Sock, asio::buffer(PingFrame));
+
+ // Should receive a pong with the same payload
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kPong);
+ CHECK_EQ(Reply.Payload.size(), 4u);
+ std::string_view PongText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(PongText, "test"sv);
+
+ Sock.close();
+ }
+
+ SUBCASE("multiple messages in sequence")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ for (int i = 0; i < 10; ++i)
+ {
+ std::string Msg = fmt::format("message {}", i);
+ std::vector<uint8_t> Frame = BuildMaskedTextFrame(Msg);
+ asio::write(Sock, asio::buffer(Frame));
+
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
+ std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(ReplyText, Msg);
+ }
+
+ CHECK_EQ(TestService.m_MessageCount.load(), 10);
+
+ Sock.close();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Integration tests: HttpWsClient
+//
+
+namespace {
+
+ struct TestWsClientHandler : public IWsClientHandler
+ {
+ void OnWsOpen() override { m_OpenCount.fetch_add(1); }
+
+ void OnWsMessage(const WebSocketMessage& Msg) override
+ {
+ if (Msg.Opcode == WebSocketOpcode::kText)
+ {
+ std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
+ m_LastMessage = std::string(Text);
+ }
+ m_MessageCount.fetch_add(1);
+ }
+
+ void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override
+ {
+ m_CloseCount.fetch_add(1);
+ m_LastCloseCode = Code;
+ }
+
+ std::atomic<int> m_OpenCount{0};
+ std::atomic<int> m_MessageCount{0};
+ std::atomic<int> m_CloseCount{0};
+ std::atomic<uint16_t> m_LastCloseCode{0};
+ std::string m_LastMessage;
+ };
+
+} // anonymous namespace
+
+TEST_CASE("websocket.client")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = Server->Initialize(7576, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ Sleep(100);
+
+ SUBCASE("connect, echo, close")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);
+
+ HttpWsClient Client(Url, Handler);
+ Client.Connect();
+
+ // Wait for OnWsOpen
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+ CHECK(Client.IsOpen());
+
+ // Send text, expect echo
+ Client.SendText("hello from client");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ CHECK_EQ(Handler.m_MessageCount.load(), 1);
+ CHECK_EQ(Handler.m_LastMessage, "hello from client");
+
+ // Close
+ Client.Close(1000, "done");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ // The server echoes the close frame, which triggers OnWsClose on the client side
+ // with the server's close code. Allow the connection to settle.
+ Sleep(50);
+ CHECK_FALSE(Client.IsOpen());
+ }
+
+ SUBCASE("connect to bad port")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = "ws://127.0.0.1:1/wstest/ws";
+
+ HttpWsClient Client(Url, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)});
+ Client.Connect();
+
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ CHECK_EQ(Handler.m_CloseCount.load(), 1);
+ CHECK_EQ(Handler.m_LastCloseCode.load(), 1006);
+ CHECK_EQ(Handler.m_OpenCount.load(), 0);
+ }
+
+ SUBCASE("server-initiated close")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);
+
+ HttpWsClient Client(Url, Handler);
+ Client.Connect();
+
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+
+ // Copy connections then close them outside the lock to avoid deadlocking
+ // with OnWebSocketClose which acquires an exclusive lock
+ std::vector<Ref<WebSocketConnection>> Conns;
+ TestService.m_ConnectionsLock.WithSharedLock([&] { Conns = TestService.m_Connections; });
+ for (auto& Conn : Conns)
+ {
+ Conn->Close(1001, "going away");
+ }
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ CHECK_EQ(Handler.m_CloseCount.load(), 1);
+ CHECK_EQ(Handler.m_LastCloseCode.load(), 1001);
+ CHECK_FALSE(Client.IsOpen());
+ }
+}
+
+TEST_SUITE_END();
+
+void
+websocket_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_TESTS
diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp
index 9135d5425..489324aba 100644
--- a/src/zenhttp/transports/dlltransport.cpp
+++ b/src/zenhttp/transports/dlltransport.cpp
@@ -72,20 +72,36 @@ DllTransportLogger::DllTransportLogger(std::string_view PluginName) : m_PluginNa
void
DllTransportLogger::LogMessage(LogLevel PluginLogLevel, const char* Message)
{
- logging::level::LogLevel Level;
- // clang-format off
switch (PluginLogLevel)
{
- case LogLevel::Trace: Level = logging::level::Trace; break;
- case LogLevel::Debug: Level = logging::level::Debug; break;
- case LogLevel::Info: Level = logging::level::Info; break;
- case LogLevel::Warn: Level = logging::level::Warn; break;
- case LogLevel::Err: Level = logging::level::Err; break;
- case LogLevel::Critical: Level = logging::level::Critical; break;
- default: Level = logging::level::Off; break;
+ case LogLevel::Trace:
+ ZEN_TRACE("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Debug:
+ ZEN_DEBUG("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Info:
+ ZEN_INFO("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Warn:
+ ZEN_WARN("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Err:
+ ZEN_ERROR("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Critical:
+ ZEN_CRITICAL("[{}] {}", m_PluginName, Message);
+ return;
+
+ default:
+ ZEN_UNUSED(Message);
+ break;
}
- // clang-format on
- ZEN_LOG(Log(), Level, "[{}] {}", m_PluginName, Message)
}
uint32_t
diff --git a/src/zenhttp/transports/winsocktransport.cpp b/src/zenhttp/transports/winsocktransport.cpp
index c06a50c95..0217ed44e 100644
--- a/src/zenhttp/transports/winsocktransport.cpp
+++ b/src/zenhttp/transports/winsocktransport.cpp
@@ -322,7 +322,7 @@ SocketTransportPluginImpl::Initialize(TransportServer* ServerInterface)
else
{
}
- } while (!IsApplicationExitRequested() && m_KeepRunning.test());
+ } while (m_KeepRunning.test());
ZEN_INFO("HTTP plugin server accept thread exit");
});
diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua
index 78876d21b..e8f87b668 100644
--- a/src/zenhttp/xmake.lua
+++ b/src/zenhttp/xmake.lua
@@ -6,6 +6,7 @@ target('zenhttp')
add_headerfiles("**.h")
add_files("**.cpp")
add_files("servers/httpsys.cpp", {unity_ignored=true})
+ add_files("servers/wshttpsys.cpp", {unity_ignored=true})
add_includedirs("include", {public=true})
add_deps("zencore", "zentelemetry", "transport-sdk", "asio", "cpr")
add_packages("http_parser", "json11")
diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp
index a2679f92e..3ac8eea8d 100644
--- a/src/zenhttp/zenhttp.cpp
+++ b/src/zenhttp/zenhttp.cpp
@@ -7,6 +7,7 @@
# include <zenhttp/httpclient.h>
# include <zenhttp/httpserver.h>
# include <zenhttp/packageformat.h>
+# include <zenhttp/security/passwordsecurity.h>
namespace zen {
@@ -15,7 +16,10 @@ zenhttp_forcelinktests()
{
http_forcelink();
httpclient_forcelink();
+ httpclient_test_forcelink();
forcelink_packageformat();
+ passwordsecurity_forcelink();
+ websocket_forcelink();
}
} // namespace zen