aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-03-15 20:42:36 +0100
committerStefan Boberg <[email protected]>2026-03-15 20:42:36 +0100
commit9c724efbf6b38466a9b6bfde37236369f1e85cb8 (patch)
tree214e1ec00c5bfca0704ce52789017ade734fd054 /src/zenhttp
parentreduced WaitForThreads time to see how it behaves with explicit thread pools (diff)
parentadd buildid updates to oplog and builds test scripts (#838) (diff)
downloadzen-9c724efbf6b38466a9b6bfde37236369f1e85cb8.tar.xz
zen-9c724efbf6b38466a9b6bfde37236369f1e85cb8.zip
Merge remote-tracking branch 'origin/main' into sb/threadpool
Diffstat (limited to 'src/zenhttp')
-rw-r--r--src/zenhttp/auth/oidc.cpp24
-rw-r--r--src/zenhttp/clients/httpclientcommon.cpp380
-rw-r--r--src/zenhttp/clients/httpclientcommon.h120
-rw-r--r--src/zenhttp/clients/httpclientcpr.cpp791
-rw-r--r--src/zenhttp/clients/httpclientcpr.h34
-rw-r--r--src/zenhttp/clients/httpclientcurl.cpp1816
-rw-r--r--src/zenhttp/clients/httpclientcurl.h137
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp641
-rw-r--r--src/zenhttp/httpclient.cpp222
-rw-r--r--src/zenhttp/httpclient_test.cpp1701
-rw-r--r--src/zenhttp/httpserver.cpp170
-rw-r--r--src/zenhttp/include/zenhttp/cprutils.h22
-rw-r--r--src/zenhttp/include/zenhttp/formatters.h4
-rw-r--r--src/zenhttp/include/zenhttp/httpapiservice.h1
-rw-r--r--src/zenhttp/include/zenhttp/httpclient.h140
-rw-r--r--src/zenhttp/include/zenhttp/httpcommon.h7
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h114
-rw-r--r--src/zenhttp/include/zenhttp/httpstats.h47
-rw-r--r--src/zenhttp/include/zenhttp/httpwsclient.h83
-rw-r--r--src/zenhttp/include/zenhttp/packageformat.h2
-rw-r--r--src/zenhttp/include/zenhttp/websocket.h65
-rw-r--r--src/zenhttp/monitoring/httpstats.cpp195
-rw-r--r--src/zenhttp/packageformat.cpp24
-rw-r--r--src/zenhttp/security/passwordsecurity.cpp5
-rw-r--r--src/zenhttp/servers/asio_socket_traits.h54
-rw-r--r--src/zenhttp/servers/httpasio.cpp917
-rw-r--r--src/zenhttp/servers/httpasio.h7
-rw-r--r--src/zenhttp/servers/httpmulti.cpp10
-rw-r--r--src/zenhttp/servers/httpmulti.h13
-rw-r--r--src/zenhttp/servers/httpparser.cpp156
-rw-r--r--src/zenhttp/servers/httpparser.h9
-rw-r--r--src/zenhttp/servers/httpplugin.cpp12
-rw-r--r--src/zenhttp/servers/httpsys.cpp786
-rw-r--r--src/zenhttp/servers/httpsys.h4
-rw-r--r--src/zenhttp/servers/httpsys_iocontext.h40
-rw-r--r--src/zenhttp/servers/httptracer.h4
-rw-r--r--src/zenhttp/servers/wsasio.cpp339
-rw-r--r--src/zenhttp/servers/wsasio.h94
-rw-r--r--src/zenhttp/servers/wsframecodec.cpp236
-rw-r--r--src/zenhttp/servers/wsframecodec.h74
-rw-r--r--src/zenhttp/servers/wshttpsys.cpp485
-rw-r--r--src/zenhttp/servers/wshttpsys.h107
-rw-r--r--src/zenhttp/servers/wstest.cpp994
-rw-r--r--src/zenhttp/transports/asiotransport.cpp14
-rw-r--r--src/zenhttp/transports/dlltransport.cpp38
-rw-r--r--src/zenhttp/xmake.lua13
-rw-r--r--src/zenhttp/zenhttp.cpp2
47 files changed, 10274 insertions, 879 deletions
diff --git a/src/zenhttp/auth/oidc.cpp b/src/zenhttp/auth/oidc.cpp
index 38e7586ad..23bbc17e8 100644
--- a/src/zenhttp/auth/oidc.cpp
+++ b/src/zenhttp/auth/oidc.cpp
@@ -32,6 +32,25 @@ namespace details {
using namespace std::literals;
+static std::string
+FormUrlEncode(std::string_view Input)
+{
+ std::string Result;
+ Result.reserve(Input.size());
+ for (char C : Input)
+ {
+ if ((C >= 'A' && C <= 'Z') || (C >= 'a' && C <= 'z') || (C >= '0' && C <= '9') || C == '-' || C == '_' || C == '.' || C == '~')
+ {
+ Result.push_back(C);
+ }
+ else
+ {
+ Result.append(fmt::format("%{:02X}", static_cast<uint8_t>(C)));
+ }
+ }
+ return Result;
+}
+
OidcClient::OidcClient(const OidcClient::Options& Options)
{
m_BaseUrl = std::string(Options.BaseUrl);
@@ -67,6 +86,8 @@ OidcClient::Initialize()
.TokenEndpoint = Json["token_endpoint"].string_value(),
.UserInfoEndpoint = Json["userinfo_endpoint"].string_value(),
.RegistrationEndpoint = Json["registration_endpoint"].string_value(),
+ .EndSessionEndpoint = Json["end_session_endpoint"].string_value(),
+ .DeviceAuthorizationEndpoint = Json["device_authorization_endpoint"].string_value(),
.JwksUri = Json["jwks_uri"].string_value(),
.SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]),
.SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]),
@@ -81,7 +102,8 @@ OidcClient::Initialize()
OidcClient::RefreshTokenResult
OidcClient::RefreshToken(std::string_view RefreshToken)
{
- const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId);
+ const std::string Body =
+ fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", FormUrlEncode(RefreshToken), FormUrlEncode(m_ClientId));
HttpClient Http{m_Config.TokenEndpoint};
diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp
index 47425e014..e4d11547a 100644
--- a/src/zenhttp/clients/httpclientcommon.cpp
+++ b/src/zenhttp/clients/httpclientcommon.cpp
@@ -142,7 +142,10 @@ namespace detail {
DataSize -= CopySize;
if (m_CacheBufferOffset == CacheBufferSize)
{
- AppendData(m_CacheBuffer, CacheBufferSize);
+ if (std::error_code Ec = AppendData(m_CacheBuffer, CacheBufferSize))
+ {
+ return Ec;
+ }
if (DataSize > 0)
{
ZEN_ASSERT(DataSize < CacheBufferSize);
@@ -382,6 +385,177 @@ namespace detail {
return Result;
}
+ MultipartBoundaryParser::MultipartBoundaryParser() : BoundaryEndMatcher("--"), HeaderEndMatcher("\r\n\r\n") {}
+
+ bool MultipartBoundaryParser::Init(const std::string_view ContentTypeHeaderValue)
+ {
+ std::string LowerCaseValue = ToLower(ContentTypeHeaderValue);
+ if (LowerCaseValue.starts_with("multipart/byteranges"))
+ {
+ size_t BoundaryPos = LowerCaseValue.find("boundary=");
+ if (BoundaryPos != std::string::npos)
+ {
+ // Yes, we do a substring of the non-lowercase value string as we want the exact boundary string
+ std::string_view BoundaryName = std::string_view(ContentTypeHeaderValue).substr(BoundaryPos + 9);
+ size_t BoundaryEnd = std::string::npos;
+ while (!BoundaryName.empty() && BoundaryName[0] == ' ')
+ {
+ BoundaryName = BoundaryName.substr(1);
+ }
+ if (!BoundaryName.empty())
+ {
+ if (BoundaryName.size() > 2 && BoundaryName.front() == '"' && BoundaryName.back() == '"')
+ {
+ BoundaryEnd = BoundaryName.find('"', 1);
+ if (BoundaryEnd != std::string::npos)
+ {
+ BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(1, BoundaryEnd - 1)));
+ return true;
+ }
+ }
+ else
+ {
+ BoundaryEnd = BoundaryName.find_first_of(" \r\n");
+ BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(0, BoundaryEnd)));
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ void MultipartBoundaryParser::ParseInput(std::string_view data)
+ {
+ const char* InputPtr = data.data();
+ size_t InputLength = data.length();
+ size_t ScanPos = 0;
+ while (ScanPos < InputLength)
+ {
+ const char ScanChar = InputPtr[ScanPos];
+ if (BoundaryBeginMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ if (PayloadOffset + ScanPos < (BoundaryBeginMatcher.GetMatchEndOffset() + BoundaryEndMatcher.GetMatchString().length()))
+ {
+ BoundaryEndMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+ if (BoundaryEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ BoundaryBeginMatcher.Reset();
+ HeaderEndMatcher.Reset();
+ BoundaryEndMatcher.Reset();
+ BoundaryHeader.Reset();
+ break;
+ }
+ }
+
+ BoundaryHeader.Append(ScanChar);
+
+ HeaderEndMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+
+ if (HeaderEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ const uint64_t HeaderStartOffset = BoundaryBeginMatcher.GetMatchEndOffset();
+ const uint64_t HeaderEndOffset = HeaderEndMatcher.GetMatchStartOffset();
+ const uint64_t HeaderLength = HeaderEndOffset - HeaderStartOffset;
+ std::string_view HeaderText(BoundaryHeader.ToView().substr(0, HeaderLength));
+
+ uint64_t OffsetInPayload = PayloadOffset + ScanPos + 1;
+
+ uint64_t RangeOffset = 0;
+ uint64_t RangeLength = 0;
+ HttpContentType ContentType = HttpContentType::kBinary;
+
+ ForEachStrTok(HeaderText, "\r\n", [&](std::string_view Line) {
+ const std::pair<std::string_view, std::string_view> KeyAndValue = GetHeaderKeyAndValue(Line);
+ const std::string_view Key = KeyAndValue.first;
+ const std::string_view Value = KeyAndValue.second;
+ if (Key == "Content-Range")
+ {
+ std::pair<uint64_t, uint64_t> ContentRange = ParseContentRange(Value);
+ if (ContentRange.second != 0)
+ {
+ RangeOffset = ContentRange.first;
+ RangeLength = ContentRange.second;
+ }
+ }
+ else if (Key == "Content-Type")
+ {
+ ContentType = ParseContentType(Value);
+ }
+
+ return true;
+ });
+
+ if (RangeLength > 0)
+ {
+ Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = OffsetInPayload,
+ .RangeOffset = RangeOffset,
+ .RangeLength = RangeLength,
+ .ContentType = ContentType});
+ }
+
+ BoundaryBeginMatcher.Reset();
+ HeaderEndMatcher.Reset();
+ BoundaryEndMatcher.Reset();
+ BoundaryHeader.Reset();
+ }
+ }
+ else
+ {
+ BoundaryBeginMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+ }
+ ScanPos++;
+ }
+ PayloadOffset += InputLength;
+ }
+
+ std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString)
+ {
+ size_t DelimiterPos = HeaderString.find(':');
+ if (DelimiterPos != std::string::npos)
+ {
+ std::string_view Key = HeaderString.substr(0, DelimiterPos);
+ constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n");
+ Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters);
+ Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters);
+
+ std::string_view Value = HeaderString.substr(DelimiterPos + 1);
+ Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters);
+ Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters);
+ return std::make_pair(Key, Value);
+ }
+ return std::make_pair(HeaderString, std::string_view{});
+ }
+
+ std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value)
+ {
+ if (Value.starts_with("bytes "))
+ {
+ size_t RangeSplitPos = Value.find('-', 6);
+ if (RangeSplitPos != std::string::npos)
+ {
+ size_t RangeEndLength = Value.find('/', RangeSplitPos + 1);
+ if (RangeEndLength == std::string::npos)
+ {
+ RangeEndLength = Value.length() - (RangeSplitPos + 1);
+ }
+ else
+ {
+ RangeEndLength = RangeEndLength - (RangeSplitPos + 1);
+ }
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(Value.substr(6, RangeSplitPos - 6));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(Value.substr(RangeSplitPos + 1, RangeEndLength));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ uint64_t RangeOffset = RequestedRangeStart.value();
+ uint64_t RangeLength = RequestedRangeEnd.value() - RangeOffset + 1;
+ return std::make_pair(RangeOffset, RangeLength);
+ }
+ }
+ }
+ return {0, 0};
+ }
+
} // namespace detail
} // namespace zen
@@ -423,6 +597,8 @@ namespace testutil {
} // namespace testutil
+TEST_SUITE_BEGIN("http.httpclientcommon");
+
TEST_CASE("BufferedReadFileStream")
{
ScopedTemporaryDirectory TmpDir;
@@ -470,5 +646,207 @@ TEST_CASE("CompositeBufferReadStream")
CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data));
}
+TEST_CASE("ParseContentRange")
+{
+ SUBCASE("normal range with total size")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 0-99/500");
+ CHECK_EQ(Offset, 0);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("non-zero offset")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 2638-5111437/44369878");
+ CHECK_EQ(Offset, 2638);
+ CHECK_EQ(Length, 5111437 - 2638 + 1);
+ }
+
+ SUBCASE("wildcard total size")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 100-199/*");
+ CHECK_EQ(Offset, 100);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("no slash (total size omitted)")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 50-149");
+ CHECK_EQ(Offset, 50);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("malformed input returns zeros")
+ {
+ auto [Offset1, Length1] = detail::ParseContentRange("not-bytes 0-99/500");
+ CHECK_EQ(Offset1, 0);
+ CHECK_EQ(Length1, 0);
+
+ auto [Offset2, Length2] = detail::ParseContentRange("bytes abc-def/500");
+ CHECK_EQ(Offset2, 0);
+ CHECK_EQ(Length2, 0);
+
+ auto [Offset3, Length3] = detail::ParseContentRange("");
+ CHECK_EQ(Offset3, 0);
+ CHECK_EQ(Length3, 0);
+
+ auto [Offset4, Length4] = detail::ParseContentRange("bytes 100/500");
+ CHECK_EQ(Offset4, 0);
+ CHECK_EQ(Length4, 0);
+ }
+
+ SUBCASE("single byte range")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 42-42/1000");
+ CHECK_EQ(Offset, 42);
+ CHECK_EQ(Length, 1);
+ }
+}
+
+TEST_CASE("MultipartBoundaryParser")
+{
+ uint64_t Range1Offset = 2638;
+ uint64_t Range1Length = (5111437 - Range1Offset) + 1;
+
+ uint64_t Range2Offset = 5118199;
+ uint64_t Range2Length = (9147741 - Range2Offset) + 1;
+
+ std::string_view ContentTypeHeaderValue1 = "multipart/byteranges; boundary=00000000000000019229";
+ std::string_view ContentTypeHeaderValue2 = "multipart/byteranges; boundary=\"00000000000000019229\"";
+
+ {
+ std::string_view Example1 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/44369878\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample1;
+ ParserExample1.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 7;
+ for (size_t Offset = 0; Offset < Example1.length(); Offset += InputWindow)
+ {
+ ParserExample1.ParseInput(Example1.substr(Offset, Min(Example1.length() - Offset, InputWindow)));
+ }
+
+ CHECK(ParserExample1.Boundaries.size() == 2);
+
+ CHECK(ParserExample1.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample1.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample1.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample1.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example2 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample2;
+ ParserExample2.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 3;
+ for (size_t Offset = 0; Offset < Example2.length(); Offset += InputWindow)
+ {
+ std::string_view Window = Example2.substr(Offset, Min(Example2.length() - Offset, InputWindow));
+ ParserExample2.ParseInput(Window);
+ }
+
+ CHECK(ParserExample2.Boundaries.size() == 2);
+
+ CHECK(ParserExample2.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample2.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample2.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample2.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example3 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita";
+
+ detail::MultipartBoundaryParser ParserExample3;
+ ParserExample3.Init(ContentTypeHeaderValue2);
+
+ const size_t InputWindow = 31;
+ for (size_t Offset = 0; Offset < Example3.length(); Offset += InputWindow)
+ {
+ ParserExample3.ParseInput(Example3.substr(Offset, Min(Example3.length() - Offset, InputWindow)));
+ }
+
+ CHECK(ParserExample3.Boundaries.size() == 2);
+
+ CHECK(ParserExample3.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample3.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample3.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample3.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example4 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "Not: really\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--000000000bait0019229\r\n"
+ "\r\n--00\r\n--000000000bait001922\r\n"
+ "\r\n\r\n\r\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "Content-Type: application/x-ue-comp\r\n"
+ "ditaditadita"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n---\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample4;
+ ParserExample4.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 3;
+ for (size_t Offset = 0; Offset < Example4.length(); Offset += InputWindow)
+ {
+ std::string_view Window = Example4.substr(Offset, Min(Example4.length() - Offset, InputWindow));
+ ParserExample4.ParseInput(Window);
+ }
+
+ CHECK(ParserExample4.Boundaries.size() == 2);
+
+ CHECK(ParserExample4.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample4.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample4.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample4.Boundaries[1].RangeLength == Range2Length);
+ }
+}
+
+TEST_SUITE_END();
+
} // namespace zen
#endif
diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h
index 1d0b7f9ea..e95e3a253 100644
--- a/src/zenhttp/clients/httpclientcommon.h
+++ b/src/zenhttp/clients/httpclientcommon.h
@@ -3,6 +3,7 @@
#pragma once
#include <zencore/compositebuffer.h>
+#include <zencore/string.h>
#include <zencore/trace.h>
#include <zenhttp/httpclient.h>
@@ -35,7 +36,10 @@ public:
const IoBuffer& Payload,
ZenContentType ContentType,
const KeyValueMap& AdditionalHeader = {}) = 0;
- [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) = 0;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ CbObject Payload,
+ const KeyValueMap& AdditionalHeader = {},
+ const std::filesystem::path& TempFolderPath = {}) = 0;
[[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) = 0;
[[nodiscard]] virtual Response Post(std::string_view Url,
const CompositeBuffer& Payload,
@@ -87,7 +91,7 @@ namespace detail {
std::error_code Write(std::string_view DataString);
IoBuffer DetachToIoBuffer();
IoBuffer BorrowIoBuffer();
- inline uint64_t GetSize() const { return m_WriteOffset; }
+ inline uint64_t GetSize() const { return m_WriteOffset + m_CacheBufferOffset; }
void ResetWritePos(uint64_t WriteOffset);
private:
@@ -143,6 +147,118 @@ namespace detail {
uint64_t m_BytesLeftInSegment;
};
+ class IncrementalStringMatcher
+ {
+ public:
+ enum class EMatchState
+ {
+ None,
+ Partial,
+ Complete
+ };
+
+ EMatchState MatchState = EMatchState::None;
+
+ IncrementalStringMatcher() {}
+
+ IncrementalStringMatcher(std::string&& InMatchString) : MatchString(std::move(InMatchString))
+ {
+ RawMatchString = MatchString.data();
+ }
+
+ void Init(std::string&& InMatchString)
+ {
+ MatchString = std::move(InMatchString);
+ RawMatchString = MatchString.data();
+ }
+
+ inline void Reset()
+ {
+ MatchLength = 0;
+ MatchStartOffset = 0;
+ MatchState = EMatchState::None;
+ }
+
+ inline uint64_t GetMatchEndOffset() const
+ {
+ if (MatchState == EMatchState::Complete)
+ {
+ return MatchStartOffset + MatchString.length();
+ }
+ return 0;
+ }
+
+ inline uint64_t GetMatchStartOffset() const
+ {
+ ZEN_ASSERT(MatchState == EMatchState::Complete);
+ return MatchStartOffset;
+ }
+
+ void Match(uint64_t Offset, char C)
+ {
+ ZEN_ASSERT_SLOW(RawMatchString != nullptr);
+
+ if (MatchState == EMatchState::Complete)
+ {
+ Reset();
+ }
+ if (C == RawMatchString[MatchLength])
+ {
+ if (MatchLength == 0)
+ {
+ MatchStartOffset = Offset;
+ }
+ MatchLength++;
+ if (MatchLength == MatchString.length())
+ {
+ MatchState = EMatchState::Complete;
+ }
+ else
+ {
+ MatchState = EMatchState::Partial;
+ }
+ }
+ else if (MatchLength != 0)
+ {
+ Reset();
+ Match(Offset, C);
+ }
+ else
+ {
+ Reset();
+ }
+ }
+ inline const std::string& GetMatchString() const { return MatchString; }
+
+ private:
+ std::string MatchString;
+ const char* RawMatchString = nullptr;
+ uint64_t MatchLength = 0;
+
+ uint64_t MatchStartOffset = 0;
+ };
+
+ class MultipartBoundaryParser
+ {
+ public:
+ std::vector<HttpClient::Response::MultipartBoundary> Boundaries;
+
+ MultipartBoundaryParser();
+ bool Init(const std::string_view ContentTypeHeaderValue);
+ void ParseInput(std::string_view data);
+
+ private:
+ IncrementalStringMatcher BoundaryBeginMatcher;
+ IncrementalStringMatcher BoundaryEndMatcher;
+ IncrementalStringMatcher HeaderEndMatcher;
+
+ ExtendableStringBuilder<64> BoundaryHeader;
+ uint64_t PayloadOffset = 0;
+ };
+
+ std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString);
+ std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value);
+
} // namespace detail
} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp
index 5d92b3b6b..a52b8f74b 100644
--- a/src/zenhttp/clients/httpclientcpr.cpp
+++ b/src/zenhttp/clients/httpclientcpr.cpp
@@ -7,11 +7,18 @@
#include <zencore/compactbinarypackage.h>
#include <zencore/compactbinaryutil.h>
#include <zencore/compress.h>
+#include <zencore/filesystem.h>
#include <zencore/iobuffer.h>
#include <zencore/iohash.h>
#include <zencore/session.h>
#include <zencore/stream.h>
#include <zenhttp/packageformat.h>
+#include <algorithm>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/ssl_options.h>
+#include <cpr/unix_socket.h>
+ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
@@ -23,69 +30,42 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti
static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
-// If we want to support different HTTP client implementations then we'll need to make this more abstract
+//////////////////////////////////////////////////////////////////////////
-HttpClientError::ResponseClass
-HttpClientError::GetResponseClass() const
+static HttpClientErrorCode
+MapCprError(cpr::ErrorCode Code)
{
- if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK)
- {
- switch ((cpr::ErrorCode)m_Error)
- {
- case cpr::ErrorCode::CONNECTION_FAILURE:
- return ResponseClass::kHttpCantConnectError;
- case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
- case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
- return ResponseClass::kHttpNoHost;
- case cpr::ErrorCode::INTERNAL_ERROR:
- case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
- case cpr::ErrorCode::NETWORK_SEND_FAILURE:
- case cpr::ErrorCode::OPERATION_TIMEDOUT:
- return ResponseClass::kHttpTimeout;
- case cpr::ErrorCode::SSL_CONNECT_ERROR:
- case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR:
- case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR:
- case cpr::ErrorCode::SSL_CACERT_ERROR:
- case cpr::ErrorCode::GENERIC_SSL_ERROR:
- return ResponseClass::kHttpSLLError;
- default:
- return ResponseClass::kHttpOtherClientError;
- }
- }
- else if (IsHttpSuccessCode(m_ResponseCode))
+ switch (Code)
{
- return ResponseClass::kSuccess;
- }
- else
- {
- switch (m_ResponseCode)
- {
- case HttpResponseCode::Unauthorized:
- return ResponseClass::kHttpUnauthorized;
- case HttpResponseCode::NotFound:
- return ResponseClass::kHttpNotFound;
- case HttpResponseCode::Forbidden:
- return ResponseClass::kHttpForbidden;
- case HttpResponseCode::Conflict:
- return ResponseClass::kHttpConflict;
- case HttpResponseCode::InternalServerError:
- return ResponseClass::kHttpInternalServerError;
- case HttpResponseCode::ServiceUnavailable:
- return ResponseClass::kHttpServiceUnavailable;
- case HttpResponseCode::BadGateway:
- return ResponseClass::kHttpBadGateway;
- case HttpResponseCode::GatewayTimeout:
- return ResponseClass::kHttpGatewayTimeout;
- default:
- if (m_ResponseCode >= HttpResponseCode::InternalServerError)
- {
- return ResponseClass::kHttpOtherServerError;
- }
- else
- {
- return ResponseClass::kHttpOtherClientError;
- }
- }
+ case cpr::ErrorCode::OK:
+ return HttpClientErrorCode::kOK;
+ case cpr::ErrorCode::CONNECTION_FAILURE:
+ return HttpClientErrorCode::kConnectionFailure;
+ case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
+ return HttpClientErrorCode::kHostResolutionFailure;
+ case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
+ return HttpClientErrorCode::kProxyResolutionFailure;
+ case cpr::ErrorCode::INTERNAL_ERROR:
+ return HttpClientErrorCode::kInternalError;
+ case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
+ return HttpClientErrorCode::kNetworkReceiveError;
+ case cpr::ErrorCode::NETWORK_SEND_FAILURE:
+ return HttpClientErrorCode::kNetworkSendFailure;
+ case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kOperationTimedOut;
+ case cpr::ErrorCode::SSL_CONNECT_ERROR:
+ return HttpClientErrorCode::kSSLConnectError;
+ case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR:
+ case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR:
+ return HttpClientErrorCode::kSSLCertificateError;
+ case cpr::ErrorCode::SSL_CACERT_ERROR:
+ return HttpClientErrorCode::kSSLCACertError;
+ case cpr::ErrorCode::GENERIC_SSL_ERROR:
+ return HttpClientErrorCode::kGenericSSLError;
+ case cpr::ErrorCode::REQUEST_CANCELLED:
+ return HttpClientErrorCode::kRequestCancelled;
+ default:
+ return HttpClientErrorCode::kOtherError;
}
}
@@ -149,6 +129,18 @@ CprHttpClient::CprHttpClient(std::string_view BaseUri,
{
}
+bool
+CprHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const
+{
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ // Quiet
+ return false;
+ }
+ const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes;
+ return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end();
+}
+
CprHttpClient::~CprHttpClient()
{
ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient");
@@ -162,10 +154,11 @@ CprHttpClient::~CprHttpClient()
}
HttpClient::Response
-CprHttpClient::ResponseWithPayload(std::string_view SessionId,
- cpr::Response&& HttpResponse,
- const HttpResponseCode WorkResponseCode,
- IoBuffer&& Payload)
+CprHttpClient::ResponseWithPayload(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
{
// This ends up doing a memcpy, would be good to get rid of it by streaming results
// into buffer directly
@@ -174,30 +167,37 @@ CprHttpClient::ResponseWithPayload(std::string_view SessionId,
if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end())
{
const HttpContentType ContentType = ParseContentType(It->second);
-
ResponseBuffer.SetContentType(ContentType);
}
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
-
- if (!Quiet)
+ if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
{
- if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
+ if (ShouldLogErrorCode(WorkResponseCode))
{
ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse);
}
}
+ std::sort(BoundaryPositions.begin(),
+ BoundaryPositions.end(),
+ [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) {
+ return Lhs.RangeOffset < Rhs.RangeOffset;
+ });
+
return HttpClient::Response{.StatusCode = WorkResponseCode,
.ResponsePayload = std::move(ResponseBuffer),
.Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
.UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
.DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
- .ElapsedSeconds = HttpResponse.elapsed};
+ .ElapsedSeconds = HttpResponse.elapsed,
+ .Ranges = std::move(BoundaryPositions)};
}
HttpClient::Response
-CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload)
+CprHttpClient::CommonResponse(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
{
const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code);
if (HttpResponse.error)
@@ -221,8 +221,8 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe
.UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
.DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
.ElapsedSeconds = HttpResponse.elapsed,
- .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code),
- .ErrorMessage = HttpResponse.error.message}};
+ .Error =
+ HttpClient::ErrorContext{.ErrorCode = MapCprError(HttpResponse.error.code), .ErrorMessage = HttpResponse.error.message}};
}
if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload))
@@ -235,7 +235,7 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe
}
else
{
- return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload));
+ return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions));
}
}
@@ -346,8 +346,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId,
}
Sleep(100 * (Attempt + 1));
Attempt++;
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
- if (!Quiet)
+ if (ShouldLogErrorCode(HttpResponseCode(Result.status_code)))
{
ZEN_INFO("{} Attempt {}/{}",
CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
@@ -385,8 +384,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId,
}
Sleep(100 * (Attempt + 1));
Attempt++;
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
- if (!Quiet)
+ if (ShouldLogErrorCode(HttpResponseCode(Result.status_code)))
{
ZEN_INFO("{} Attempt {}/{}",
CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
@@ -492,6 +490,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl,
{
CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}});
}
+ if (ConnectionSettings.ForbidReuseConnection)
+ {
+ CprSession->UpdateHeader({{"Connection", "close"}});
+ }
if (AccessToken)
{
CprSession->UpdateHeader({{"Authorization", AccessToken->Value}});
@@ -510,6 +512,26 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl,
CprSession->SetParameters({});
}
+ if (!ConnectionSettings.UnixSocketPath.empty())
+ {
+ CprSession->SetUnixSocket(cpr::UnixSocket(PathToUtf8(ConnectionSettings.UnixSocketPath)));
+ }
+
+ if (ConnectionSettings.InsecureSsl || !ConnectionSettings.CaBundlePath.empty())
+ {
+ cpr::SslOptions SslOpts;
+ if (ConnectionSettings.InsecureSsl)
+ {
+ SslOpts.SetOption(cpr::ssl::VerifyHost{false});
+ SslOpts.SetOption(cpr::ssl::VerifyPeer{false});
+ }
+ if (!ConnectionSettings.CaBundlePath.empty())
+ {
+ SslOpts.SetOption(cpr::ssl::CaInfo{ConnectionSettings.CaBundlePath});
+ }
+ CprSession->SetSslOptions(SslOpts);
+ }
+
ExtendableStringBuilder<128> UrlBuffer;
UrlBuffer << BaseUrl << ResourcePath;
CprSession->SetUrl(UrlBuffer.c_str());
@@ -621,7 +643,7 @@ CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const Ke
ResponseBuffer.SetContentType(ContentType);
}
- return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer};
+ return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = std::move(ResponseBuffer)};
}
//////////////////////////////////////////////////////////////////////////
@@ -774,22 +796,97 @@ CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentTyp
}
CprHttpClient::Response
-CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader)
+CprHttpClient::Post(std::string_view Url,
+ CbObject Payload,
+ const KeyValueMap& AdditionalHeader,
+ const std::filesystem::path& TempFolderPath)
{
ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload");
- return CommonResponse(
+ std::string PayloadString;
+ std::unique_ptr<detail::TempPayloadFile> PayloadFile;
+
+ cpr::Response Response = DoWithRetry(
m_SessionId,
- DoWithRetry(m_SessionId,
- [&]() {
- Session Sess =
- AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ [&]() {
+ PayloadString.clear();
+ PayloadFile.reset();
- Sess->SetBody(AsCprBody(Payload));
- Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)});
- return Sess.Post();
- }),
- {});
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+
+ Sess->SetBody(AsCprBody(Payload));
+ Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)});
+
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header);
+ if (StrCaseCompare(std::string(Header.first).c_str(), "Content-Length") == 0)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
+ if (ContentLength.has_value())
+ {
+ if (!TempFolderPath.empty() && ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize)
+ {
+ PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ PayloadFile.reset();
+ }
+ }
+ else
+ {
+ PayloadString.reserve(ContentLength.value());
+ }
+ }
+ }
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+ return 1;
+ };
+
+ auto DownloadCallback = [&](std::string data, intptr_t) {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return false;
+ }
+
+ if (PayloadFile)
+ {
+ ZEN_ASSERT(PayloadString.empty());
+ std::error_code Ec = PayloadFile->Write(data);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ return false;
+ }
+ }
+ else
+ {
+ PayloadString.append(data);
+ }
+ return true;
+ };
+ cpr::Response Response = Sess.Post({}, cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ if (!PayloadString.empty())
+ {
+ Response.text = std::move(PayloadString);
+ }
+ return Response;
+ },
+ PayloadFile);
+ return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{});
}
CprHttpClient::Response
@@ -896,236 +993,292 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF
std::string PayloadString;
std::unique_ptr<detail::TempPayloadFile> PayloadFile;
- cpr::Response Response = DoWithRetry(
- m_SessionId,
- [&]() {
- auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> {
- size_t DelimiterPos = header.find(':');
- if (DelimiterPos != std::string::npos)
- {
- std::string Key = header.substr(0, DelimiterPos);
- constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n");
- Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters);
- Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters);
-
- std::string Value = header.substr(DelimiterPos + 1);
- Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters);
- Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters);
-
- return std::make_pair(Key, Value);
- }
- return std::make_pair(header, "");
- };
-
- auto DownloadCallback = [&](std::string data, intptr_t) {
- if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
- {
- return false;
- }
- if (PayloadFile)
- {
- ZEN_ASSERT(PayloadString.empty());
- std::error_code Ec = PayloadFile->Write(data);
- if (Ec)
- {
- ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
- TempFolderPath.string(),
- Ec.message());
- return false;
- }
- }
- else
- {
- PayloadString.append(data);
- }
- return true;
- };
-
- uint64_t RequestedContentLength = (uint64_t)-1;
- if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
- {
- if (RangeIt->second.starts_with("bytes"))
- {
- size_t RangeStartPos = RangeIt->second.find('=', 5);
- if (RangeStartPos != std::string::npos)
- {
- RangeStartPos++;
- size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos);
- if (RangeSplitPos != std::string::npos)
- {
- std::optional<size_t> RequestedRangeStart =
- ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos));
- std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1));
- if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
- {
- RequestedContentLength = RequestedRangeEnd.value() - 1;
- }
- }
- }
- }
- }
-
- cpr::Response Response;
- {
- std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
- auto HeaderCallback = [&](std::string header, intptr_t) {
- std::pair<std::string, std::string> Header = GetHeader(header);
- if (Header.first == "Content-Length"sv)
- {
- std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
- if (ContentLength.has_value())
- {
- if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize)
- {
- PayloadFile = std::make_unique<detail::TempPayloadFile>();
- std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
- if (Ec)
- {
- ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
- TempFolderPath.string(),
- Ec.message());
- PayloadFile.reset();
- }
- }
- else
- {
- PayloadString.reserve(ContentLength.value());
- }
- }
- }
- if (!Header.first.empty())
- {
- ReceivedHeaders.emplace_back(std::move(Header));
- }
- return 1;
- };
-
- Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
- Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
- for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
- {
- Response.header.insert_or_assign(H.first, H.second);
- }
- }
- if (m_ConnectionSettings.AllowResume)
- {
- auto SupportsRanges = [](const cpr::Response& Response) -> bool {
- if (Response.header.find("Content-Range") != Response.header.end())
- {
- return true;
- }
- if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end())
- {
- return It->second == "bytes"sv;
- }
- return false;
- };
-
- auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool {
- if (ShouldRetry(Response))
- {
- return SupportsRanges(Response);
- }
- return false;
- };
-
- if (ShouldResume(Response))
- {
- auto It = Response.header.find("Content-Length");
- if (It != Response.header.end())
- {
- uint64_t ContentLength = RequestedContentLength;
- if (ContentLength == uint64_t(-1))
- {
- if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value())
- {
- ContentLength = ParsedContentLength.value();
- }
- }
-
- std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
-
- auto HeaderCallback = [&](std::string header, intptr_t) {
- std::pair<std::string, std::string> Header = GetHeader(header);
- if (!Header.first.empty())
- {
- ReceivedHeaders.emplace_back(std::move(Header));
- }
-
- if (Header.first == "Content-Range"sv)
- {
- if (Header.second.starts_with("bytes "sv))
- {
- size_t RangeStartEnd = Header.second.find('-', 6);
- if (RangeStartEnd != std::string::npos)
- {
- const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6));
- if (Start)
- {
- uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
- if (Start.value() == DownloadedSize)
- {
- return 1;
- }
- else if (Start.value() > DownloadedSize)
- {
- return 0;
- }
- if (PayloadFile)
- {
- PayloadFile->ResetWritePos(Start.value());
- }
- else
- {
- PayloadString = PayloadString.substr(0, Start.value());
- }
- return 1;
- }
- }
- }
- return 0;
- }
- return 1;
- };
-
- KeyValueMap HeadersWithRange(AdditionalHeader);
- do
- {
- uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
-
- std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
- if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
- {
- if (RangeIt->second == Range)
- {
- // If we didn't make any progress, abort
- break;
- }
- }
- HeadersWithRange.Entries.insert_or_assign("Range", Range);
-
- Session Sess =
- AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
- Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
- for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
- {
- Response.header.insert_or_assign(H.first, H.second);
- }
- ReceivedHeaders.clear();
- } while (ShouldResume(Response));
- }
- }
- }
-
- if (!PayloadString.empty())
- {
- Response.text = std::move(PayloadString);
- }
- return Response;
- },
- PayloadFile);
- return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{});
+ HttpContentType ContentType = HttpContentType::kUnknownContentType;
+ detail::MultipartBoundaryParser BoundaryParser;
+ bool IsMultiRangeResponse = false;
+
+ cpr::Response Response = DoWithRetry(
+ m_SessionId,
+ [&]() {
+ // Reset state from any previous attempt
+ PayloadString.clear();
+ PayloadFile.reset();
+ BoundaryParser.Boundaries.clear();
+ ContentType = HttpContentType::kUnknownContentType;
+ IsMultiRangeResponse = false;
+
+ auto DownloadCallback = [&](std::string data, intptr_t) {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return false;
+ }
+
+ if (IsMultiRangeResponse)
+ {
+ BoundaryParser.ParseInput(data);
+ }
+
+ if (PayloadFile)
+ {
+ ZEN_ASSERT(PayloadString.empty());
+ std::error_code Ec = PayloadFile->Write(data);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ return false;
+ }
+ }
+ else
+ {
+ PayloadString.append(data);
+ }
+ return true;
+ };
+
+ uint64_t RequestedContentLength = (uint64_t)-1;
+ if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
+ {
+ if (RangeIt->second.starts_with("bytes"))
+ {
+ std::string_view RangeValue(RangeIt->second);
+ size_t RangeStartPos = RangeValue.find('=', 5);
+ if (RangeStartPos != std::string::npos)
+ {
+ RangeStartPos++;
+ while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ')
+ {
+ RangeStartPos++;
+ }
+ RequestedContentLength = 0;
+
+ while (RangeStartPos < RangeValue.length())
+ {
+ size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos);
+ if (RangeEnd == std::string::npos)
+ {
+ RangeEnd = RangeValue.length();
+ }
+
+ std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos);
+ size_t RangeSplitPos = RangeString.find('-');
+ if (RangeSplitPos != std::string::npos)
+ {
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1;
+ }
+ }
+ RangeStartPos = RangeEnd;
+ while (RangeStartPos != RangeValue.length() &&
+ (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' '))
+ {
+ RangeStartPos++;
+ }
+ }
+ }
+ }
+ }
+
+ cpr::Response Response;
+ {
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ if (RequestedContentLength != (uint64_t)-1 && RequestedContentLength > m_ConnectionSettings.MaximumInMemoryDownloadSize)
+ {
+ ZEN_DEBUG("Multirange request");
+ }
+ const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header);
+ const std::string Key(Header.first);
+ if (StrCaseCompare(Key.c_str(), "Content-Length") == 0)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
+ if (ContentLength.has_value())
+ {
+ if (!TempFolderPath.empty() && ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize)
+ {
+ PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ PayloadFile.reset();
+ }
+ }
+ else
+ {
+ PayloadString.reserve(ContentLength.value());
+ }
+ }
+ }
+ else if (StrCaseCompare(Key.c_str(), "Content-Type") == 0)
+ {
+ IsMultiRangeResponse = BoundaryParser.Init(Header.second);
+ if (!IsMultiRangeResponse)
+ {
+ ContentType = ParseContentType(Header.second);
+ }
+ }
+ else if (StrCaseCompare(Key.c_str(), "Content-Range") == 0)
+ {
+ if (!IsMultiRangeResponse)
+ {
+ std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Header.second);
+ if (Range.second != 0)
+ {
+ BoundaryParser.Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0,
+ .RangeOffset = Range.first,
+ .RangeLength = Range.second,
+ .ContentType = ContentType});
+ }
+ }
+ }
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+ return 1;
+ };
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ }
+ if (m_ConnectionSettings.AllowResume)
+ {
+ auto SupportsRanges = [](const cpr::Response& Response) -> bool {
+ if (Response.header.find("Content-Range") != Response.header.end())
+ {
+ return true;
+ }
+ if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end())
+ {
+ return It->second == "bytes"sv;
+ }
+ return false;
+ };
+
+ auto ShouldResume = [&SupportsRanges, &IsMultiRangeResponse](const cpr::Response& Response) -> bool {
+ if (IsMultiRangeResponse)
+ {
+ return false;
+ }
+ if (ShouldRetry(Response))
+ {
+ return SupportsRanges(Response);
+ }
+ return false;
+ };
+
+ if (ShouldResume(Response))
+ {
+ auto It = Response.header.find("Content-Length");
+ if (It != Response.header.end())
+ {
+ uint64_t ContentLength = RequestedContentLength;
+ if (ContentLength == uint64_t(-1))
+ {
+ if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value())
+ {
+ ContentLength = ParsedContentLength.value();
+ }
+ }
+
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header);
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+
+ if (StrCaseCompare(std::string(Header.first).c_str(), "Content-Range") == 0)
+ {
+ if (Header.second.starts_with("bytes "sv))
+ {
+ size_t RangeStartEnd = Header.second.find('-', 6);
+ if (RangeStartEnd != std::string::npos)
+ {
+ const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6));
+ if (Start)
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+ if (Start.value() == DownloadedSize)
+ {
+ return 1;
+ }
+ else if (Start.value() > DownloadedSize)
+ {
+ return 0;
+ }
+ if (PayloadFile)
+ {
+ PayloadFile->ResetWritePos(Start.value());
+ }
+ else
+ {
+ PayloadString = PayloadString.substr(0, Start.value());
+ }
+ return 1;
+ }
+ }
+ }
+ return 0;
+ }
+ return 1;
+ };
+
+ KeyValueMap HeadersWithRange(AdditionalHeader);
+ do
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+
+ std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
+ if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
+ {
+ if (RangeIt->second == Range)
+ {
+ // If we didn't make any progress, abort
+ break;
+ }
+ }
+ HeadersWithRange.Entries.insert_or_assign("Range", Range);
+
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ ReceivedHeaders.clear();
+ } while (ShouldResume(Response));
+ }
+ }
+ }
+
+ if (!PayloadString.empty())
+ {
+ Response.text = std::move(PayloadString);
+ }
+ return Response;
+ },
+ PayloadFile);
+
+ return CommonResponse(m_SessionId,
+ std::move(Response),
+ PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{},
+ std::move(BoundaryParser.Boundaries));
}
} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h
index 40af53b5d..009e6fb7a 100644
--- a/src/zenhttp/clients/httpclientcpr.h
+++ b/src/zenhttp/clients/httpclientcpr.h
@@ -38,7 +38,10 @@ public:
const IoBuffer& Payload,
ZenContentType ContentType,
const KeyValueMap& AdditionalHeader = {}) override;
- [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ CbObject Payload,
+ const KeyValueMap& AdditionalHeader = {},
+ const std::filesystem::path& TempFolderPath = {}) override;
[[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override;
[[nodiscard]] virtual Response Post(std::string_view Url,
const CompositeBuffer& Payload,
@@ -104,15 +107,27 @@ private:
CprSession->SetReadCallback({});
return Result;
}
- inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {})
+ inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {},
+ std::optional<cpr::WriteCallback>&& Write = {},
+ std::optional<cpr::HeaderCallback>&& Header = {})
{
ZEN_TRACE_CPU("HttpClient::Impl::Post");
if (Read)
{
CprSession->SetReadCallback(std::move(Read.value()));
}
+ if (Write)
+ {
+ CprSession->SetWriteCallback(std::move(Write.value()));
+ }
+ if (Header)
+ {
+ CprSession->SetHeaderCallback(std::move(Header.value()));
+ }
cpr::Response Result = CprSession->Post();
ZEN_TRACE("POST {}", Result);
+ CprSession->SetHeaderCallback({});
+ CprSession->SetWriteCallback({});
CprSession->SetReadCallback({});
return Result;
}
@@ -155,14 +170,19 @@ private:
std::function<cpr::Response()>&& Func,
std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; });
+ bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const;
bool ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
- HttpClient::Response CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload);
+ HttpClient::Response CommonResponse(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {});
- HttpClient::Response ResponseWithPayload(std::string_view SessionId,
- cpr::Response&& HttpResponse,
- const HttpResponseCode WorkResponseCode,
- IoBuffer&& Payload);
+ HttpClient::Response ResponseWithPayload(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions);
};
} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp
new file mode 100644
index 000000000..ec9b7bac6
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcurl.cpp
@@ -0,0 +1,1816 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpclientcurl.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryutil.h>
+#include <zencore/compress.h>
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/session.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zenhttp/packageformat.h>
+#include <algorithm>
+
+namespace zen {
+
+HttpClientBase*
+CreateCurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction)
+{
+ return new CurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+}
+
+static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0};
+
+//////////////////////////////////////////////////////////////////////////
+
+static HttpClientErrorCode
+MapCurlError(CURLcode Code)
+{
+ switch (Code)
+ {
+ case CURLE_OK:
+ return HttpClientErrorCode::kOK;
+ case CURLE_COULDNT_CONNECT:
+ return HttpClientErrorCode::kConnectionFailure;
+ case CURLE_COULDNT_RESOLVE_HOST:
+ return HttpClientErrorCode::kHostResolutionFailure;
+ case CURLE_COULDNT_RESOLVE_PROXY:
+ return HttpClientErrorCode::kProxyResolutionFailure;
+ case CURLE_RECV_ERROR:
+ return HttpClientErrorCode::kNetworkReceiveError;
+ case CURLE_SEND_ERROR:
+ return HttpClientErrorCode::kNetworkSendFailure;
+ case CURLE_OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kOperationTimedOut;
+ case CURLE_SSL_CONNECT_ERROR:
+ return HttpClientErrorCode::kSSLConnectError;
+ case CURLE_SSL_CERTPROBLEM:
+ return HttpClientErrorCode::kSSLCertificateError;
+ case CURLE_PEER_FAILED_VERIFICATION:
+ return HttpClientErrorCode::kSSLCACertError;
+ case CURLE_SSL_CIPHER:
+ case CURLE_SSL_ENGINE_NOTFOUND:
+ case CURLE_SSL_ENGINE_SETFAILED:
+ return HttpClientErrorCode::kGenericSSLError;
+ case CURLE_ABORTED_BY_CALLBACK:
+ return HttpClientErrorCode::kRequestCancelled;
+ default:
+ return HttpClientErrorCode::kOtherError;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Curl callback helpers
+
+struct WriteCallbackData
+{
+ std::string* Body = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<WriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0; // Signal abort to curl
+ }
+
+ Data->Body->append(Ptr, TotalBytes);
+ return TotalBytes;
+}
+
+struct HeaderCallbackData
+{
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+};
+
+// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value.
+// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines).
+static std::optional<std::pair<std::string_view, std::string_view>>
+ParseHeaderLine(std::string_view Line)
+{
+ while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
+ {
+ Line.remove_suffix(1);
+ }
+
+ if (Line.empty())
+ {
+ return std::nullopt;
+ }
+
+ size_t ColonPos = Line.find(':');
+ if (ColonPos == std::string_view::npos)
+ {
+ return std::nullopt;
+ }
+
+ std::string_view Key = Line.substr(0, ColonPos);
+ std::string_view Value = Line.substr(ColonPos + 1);
+
+ while (!Key.empty() && Key.back() == ' ')
+ {
+ Key.remove_suffix(1);
+ }
+ while (!Value.empty() && Value.front() == ' ')
+ {
+ Value.remove_prefix(1);
+ }
+
+ return std::pair{Key, Value};
+}
+
+static size_t
+CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<HeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)))
+ {
+ auto& [Key, Value] = *Header;
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+}
+
+struct ReadCallbackData
+{
+ const uint8_t* DataPtr = nullptr;
+ size_t DataSize = 0;
+ size_t Offset = 0;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<ReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ size_t Remaining = Data->DataSize - Data->Offset;
+ size_t ToRead = std::min(MaxRead, Remaining);
+
+ if (ToRead > 0)
+ {
+ memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead);
+ Data->Offset += ToRead;
+ }
+
+ return ToRead;
+}
+
+struct StreamReadCallbackData
+{
+ detail::CompositeBufferReadStream* Reader = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlStreamReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<StreamReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ return Data->Reader->Read(Buffer, MaxRead);
+}
+
+struct FileReadCallbackData
+{
+ detail::BufferedReadFileStream* Buffer = nullptr;
+ uint64_t TotalSize = 0;
+ uint64_t Offset = 0;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlFileReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<FileReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ size_t Remaining = Data->TotalSize - Data->Offset;
+ size_t ToRead = std::min(MaxRead, Remaining);
+
+ if (ToRead > 0)
+ {
+ Data->Buffer->Read(Buffer, ToRead);
+ Data->Offset += ToRead;
+ }
+
+ return ToRead;
+}
+
+static int
+CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, void* UserPtr)
+{
+ ZEN_UNUSED(Handle);
+ LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr);
+ auto Log = [&]() -> LoggerRef { return LogRef; };
+
+ std::string_view DataView(Data, Size);
+
+ // Trim trailing newlines
+ while (!DataView.empty() && (DataView.back() == '\r' || DataView.back() == '\n'))
+ {
+ DataView.remove_suffix(1);
+ }
+
+ switch (Type)
+ {
+ case CURLINFO_TEXT:
+ if (DataView.find("need more data"sv) == std::string_view::npos)
+ {
+ ZEN_INFO("TEXT: {}", DataView);
+ }
+ break;
+ case CURLINFO_HEADER_IN:
+ ZEN_INFO("HIN : {}", DataView);
+ break;
+ case CURLINFO_HEADER_OUT:
+ if (auto TokenPos = DataView.find("Authorization: Bearer "sv); TokenPos != std::string_view::npos)
+ {
+ std::string Copy(DataView);
+ auto BearerStart = TokenPos + 22;
+ auto BearerEnd = Copy.find_first_of("\r\n", BearerStart);
+ if (BearerEnd == std::string::npos)
+ {
+ BearerEnd = Copy.length();
+ }
+ Copy.replace(Copy.begin() + BearerStart, Copy.begin() + BearerEnd, fmt::format("[{} char token]", BearerEnd - BearerStart));
+ ZEN_INFO("HOUT: {}", Copy);
+ }
+ else
+ {
+ ZEN_INFO("HOUT: {}", DataView);
+ }
+ break;
+ default:
+ break;
+ }
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+static std::pair<std::string, std::string>
+HeaderContentType(ZenContentType ContentType)
+{
+ return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)));
+}
+
+static curl_slist*
+BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
+ std::string_view SessionId,
+ const std::optional<HttpClientAccessToken>& AccessToken,
+ const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {})
+{
+ curl_slist* Headers = nullptr;
+
+ for (const auto& [Key, Value] : *AdditionalHeader)
+ {
+ ExtendableStringBuilder<64> HeaderLine;
+ HeaderLine << Key << ": " << Value;
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ if (!SessionId.empty())
+ {
+ ExtendableStringBuilder<64> SessionHeader;
+ SessionHeader << "UE-Session: " << SessionId;
+ Headers = curl_slist_append(Headers, SessionHeader.c_str());
+ }
+
+ if (AccessToken)
+ {
+ ExtendableStringBuilder<128> AuthHeader;
+ AuthHeader << "Authorization: " << AccessToken->Value;
+ Headers = curl_slist_append(Headers, AuthHeader.c_str());
+ }
+
+ for (const auto& [Key, Value] : ExtraHeaders)
+ {
+ ExtendableStringBuilder<128> HeaderLine;
+ HeaderLine << Key << ": " << Value;
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ return Headers;
+}
+
+static HttpClient::KeyValueMap
+BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers)
+{
+ HttpClient::KeyValueMap HeaderMap;
+ for (const auto& [Key, Value] : Headers)
+ {
+ HeaderMap->insert_or_assign(Key, Value);
+ }
+ return HeaderMap;
+}
+
+// Scans response headers for Content-Type and applies it to the buffer.
+static void
+ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers)
+{
+ for (const auto& [Key, Value] : Headers)
+ {
+ if (StrCaseCompare(Key, "Content-Type") == 0)
+ {
+ Buffer.SetContentType(ParseContentType(Value));
+ break;
+ }
+ }
+}
+
+static void
+AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input)
+{
+ static constexpr char HexDigits[] = "0123456789ABCDEF";
+ static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~");
+
+ for (char C : Input)
+ {
+ if (Unreserved.Contains(C))
+ {
+ Out.Append(C);
+ }
+ else
+ {
+ uint8_t Byte = static_cast<uint8_t>(C);
+ char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]};
+ Out.Append(std::string_view(Encoded, 3));
+ }
+ }
+}
+
+static void
+BuildUrlWithParameters(StringBuilderBase& Url,
+ std::string_view BaseUrl,
+ std::string_view ResourcePath,
+ const HttpClient::KeyValueMap& Parameters)
+{
+ Url.Append(BaseUrl);
+ Url.Append(ResourcePath);
+
+ if (!Parameters->empty())
+ {
+ char Separator = '?';
+ for (const auto& [Key, Value] : *Parameters)
+ {
+ Url.Append(Separator);
+ AppendUrlEncoded(Url, Key);
+ Url.Append('=');
+ AppendUrlEncoded(Url, Value);
+ Separator = '&';
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CurlHttpClient::CurlHttpClient(std::string_view BaseUri,
+ const HttpClientSettings& ConnectionSettings,
+ std::function<bool()>&& CheckIfAbortFunction)
+: HttpClientBase(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction))
+{
+}
+
+CurlHttpClient::~CurlHttpClient()
+{
+ ZEN_TRACE_CPU("CurlHttpClient::~CurlHttpClient");
+ m_SessionLock.WithExclusiveLock([&] {
+ for (auto* Handle : m_Sessions)
+ {
+ curl_easy_cleanup(Handle);
+ }
+ m_Sessions.clear();
+ });
+}
+
+CurlHttpClient::Session::~Session()
+{
+ if (HeaderList)
+ {
+ curl_slist_free_all(HeaderList);
+ }
+ Outer->ReleaseSession(Handle);
+}
+
+void
+CurlHttpClient::Session::SetHeaders(curl_slist* Headers)
+{
+ if (HeaderList)
+ {
+ curl_slist_free_all(HeaderList);
+ }
+ HeaderList = Headers;
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, HeaderList);
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::Session::PerformWithResponseCallbacks()
+{
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body,
+ .CheckIfAbortFunction = Outer->m_CheckIfAbortFunction ? &Outer->m_CheckIfAbortFunction : nullptr};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ return Result;
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::Session::Perform()
+{
+ CurlResult Result;
+
+ char ErrorBuffer[CURL_ERROR_SIZE] = {};
+ curl_easy_setopt(Handle, CURLOPT_ERRORBUFFER, ErrorBuffer);
+
+ Result.ErrorCode = curl_easy_perform(Handle);
+
+ if (Result.ErrorCode != CURLE_OK)
+ {
+ Result.ErrorMessage = ErrorBuffer[0] ? std::string(ErrorBuffer) : curl_easy_strerror(Result.ErrorCode);
+ }
+
+ curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &Result.StatusCode);
+
+ double Elapsed = 0;
+ curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed);
+ Result.ElapsedSeconds = Elapsed;
+
+ curl_off_t UpBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes);
+ Result.UploadedBytes = static_cast<int64_t>(UpBytes);
+
+ curl_off_t DownBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes);
+ Result.DownloadedBytes = static_cast<int64_t>(DownBytes);
+
+ return Result;
+}
+
+bool
+CurlHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const
+{
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return false;
+ }
+ const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes;
+ return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end();
+}
+
+HttpClient::Response
+CurlHttpClient::ResponseWithPayload(std::string_view SessionId,
+ CurlResult&& Result,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
+{
+ IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, Result.Body.data(), Result.Body.size());
+
+ ApplyContentTypeFromHeaders(ResponseBuffer, Result.Headers);
+
+ if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
+ {
+ if (ShouldLogErrorCode(WorkResponseCode))
+ {
+ ZEN_WARN("HttpClient request failed (session: {}): status={}, url={}",
+ SessionId,
+ static_cast<int>(WorkResponseCode),
+ m_BaseUri);
+ }
+ }
+
+ std::sort(BoundaryPositions.begin(),
+ BoundaryPositions.end(),
+ [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) {
+ return Lhs.RangeOffset < Rhs.RangeOffset;
+ });
+
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .ResponsePayload = std::move(ResponseBuffer),
+ .Header = BuildHeaderMap(Result.Headers),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Ranges = std::move(BoundaryPositions)};
+}
+
+HttpClient::Response
+CurlHttpClient::CommonResponse(std::string_view SessionId,
+ CurlResult&& Result,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
+{
+ const HttpResponseCode WorkResponseCode = HttpResponseCode(Result.StatusCode);
+ if (Result.ErrorCode != CURLE_OK)
+ {
+ const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
+ if (!Quiet)
+ {
+ if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT &&
+ Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK)
+ {
+ ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'",
+ SessionId,
+ static_cast<int>(Result.ErrorCode),
+ Result.ErrorMessage);
+ }
+ }
+
+ return HttpClient::Response{
+ .StatusCode = WorkResponseCode,
+ .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Result.Body.data(), Result.Body.size()),
+ .Header = BuildHeaderMap(Result.Headers),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(Result.ErrorCode), .ErrorMessage = Result.ErrorMessage}};
+ }
+
+ if (WorkResponseCode == HttpResponseCode::NoContent || (Result.Body.empty() && !Payload))
+ {
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .Header = BuildHeaderMap(Result.Headers),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds};
+ }
+ else
+ {
+ return ResponseWithPayload(SessionId, std::move(Result), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions));
+ }
+}
+
+bool
+CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile)
+{
+ ZEN_TRACE_CPU("ValidatePayload");
+
+ IoBuffer ResponseBuffer = (Result.Body.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer()
+ : IoBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size());
+
+ // Collect relevant headers in a single pass
+ std::string_view ContentLengthValue;
+ std::string_view IoHashValue;
+ std::string_view ContentTypeValue;
+
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ if (ContentLengthValue.empty() && StrCaseCompare(Key, "Content-Length") == 0)
+ {
+ ContentLengthValue = Value;
+ }
+ else if (IoHashValue.empty() && StrCaseCompare(Key, "X-Jupiter-IoHash") == 0)
+ {
+ IoHashValue = Value;
+ }
+ else if (ContentTypeValue.empty() && StrCaseCompare(Key, "Content-Type") == 0)
+ {
+ ContentTypeValue = Value;
+ }
+ }
+
+ // Validate Content-Length
+ if (!ContentLengthValue.empty())
+ {
+ std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLengthValue);
+ if (!ExpectedContentSize.has_value())
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLengthValue);
+ return false;
+ }
+ if (ExpectedContentSize.value() != ResponseBuffer.GetSize())
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage =
+ fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLengthValue);
+ return false;
+ }
+ }
+
+ if (Result.StatusCode == static_cast<long>(HttpResponseCode::PartialContent))
+ {
+ return true;
+ }
+
+ // Validate X-Jupiter-IoHash
+ if (!IoHashValue.empty())
+ {
+ IoHash ExpectedPayloadHash;
+ if (IoHash::TryParse(IoHashValue, ExpectedPayloadHash))
+ {
+ IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer);
+ if (PayloadHash != ExpectedPayloadHash)
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}",
+ PayloadHash.ToHexString(),
+ ExpectedPayloadHash.ToHexString());
+ return false;
+ }
+ }
+ }
+
+ // Validate content-type specific payload
+ if (ContentTypeValue == "application/x-ue-comp")
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer,
+ RawHash,
+ RawSize,
+ /*OutOptionalTotalCompressedSize*/ nullptr))
+ {
+ return true;
+ }
+ else
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = "Compressed binary failed validation";
+ return false;
+ }
+ }
+ if (ContentTypeValue == "application/x-ue-cb")
+ {
+ if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default);
+ Error == CbValidateError::None)
+ {
+ return true;
+ }
+ else
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error));
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool
+CurlHttpClient::ShouldRetry(const CurlResult& Result)
+{
+ switch (Result.ErrorCode)
+ {
+ case CURLE_OK:
+ break;
+ case CURLE_RECV_ERROR:
+ case CURLE_SEND_ERROR:
+ case CURLE_OPERATION_TIMEDOUT:
+ return true;
+ default:
+ return false;
+ }
+ switch (static_cast<HttpResponseCode>(Result.StatusCode))
+ {
+ case HttpResponseCode::RequestTimeout:
+ case HttpResponseCode::TooManyRequests:
+ case HttpResponseCode::InternalServerError:
+ case HttpResponseCode::BadGateway:
+ case HttpResponseCode::ServiceUnavailable:
+ case HttpResponseCode::GatewayTimeout:
+ return true;
+ default:
+ return false;
+ }
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::function<bool(CurlResult&)>&& Validate)
+{
+ uint8_t Attempt = 0;
+ CurlResult Result = Func();
+ while (Attempt < m_ConnectionSettings.RetryCount)
+ {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return Result;
+ }
+ if (!ShouldRetry(Result))
+ {
+ if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode))
+ {
+ break;
+ }
+ if (Validate(Result))
+ {
+ break;
+ }
+ }
+ Sleep(100 * (Attempt + 1));
+ Attempt++;
+ if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode)))
+ {
+ if (Result.ErrorCode != CURLE_OK)
+ {
+ ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}",
+ SessionId,
+ static_cast<int>(MapCurlError(Result.ErrorCode)),
+ Result.ErrorMessage,
+ Attempt,
+ m_ConnectionSettings.RetryCount + 1);
+ }
+ else
+ {
+ ZEN_INFO("Retry (session: {}): HTTP status ({}) '{}' Attempt {}/{}",
+ SessionId,
+ Result.StatusCode,
+ zen::ToString(HttpResponseCode(Result.StatusCode)),
+ Attempt,
+ m_ConnectionSettings.RetryCount + 1);
+ }
+ }
+ Result = Func();
+ }
+ return Result;
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::DoWithRetry(std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::unique_ptr<detail::TempPayloadFile>& PayloadFile)
+{
+ return DoWithRetry(SessionId, std::move(Func), [&](CurlResult& Result) { return ValidatePayload(Result, PayloadFile); });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CurlHttpClient::Session
+CurlHttpClient::AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::AllocSession");
+ CURL* Handle = nullptr;
+ m_SessionLock.WithExclusiveLock([&] {
+ if (!m_Sessions.empty())
+ {
+ Handle = m_Sessions.back();
+ m_Sessions.pop_back();
+ }
+ });
+
+ if (Handle == nullptr)
+ {
+ Handle = curl_easy_init();
+ if (Handle == nullptr)
+ {
+ ThrowOutOfMemory("curl_easy_init");
+ }
+ }
+ else
+ {
+ curl_easy_reset(Handle);
+ }
+
+ // Unix domain socket
+ if (!m_ConnectionSettings.UnixSocketPath.empty())
+ {
+ std::string SocketPathUtf8 = PathToUtf8(m_ConnectionSettings.UnixSocketPath);
+ curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, SocketPathUtf8.c_str());
+ }
+
+ // Build URL with parameters
+ ExtendableStringBuilder<256> Url;
+ BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str());
+
+ // Timeouts
+ if (m_ConnectionSettings.ConnectTimeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(m_ConnectionSettings.ConnectTimeout.count()));
+ }
+ if (m_ConnectionSettings.Timeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_ConnectionSettings.Timeout.count()));
+ }
+
+ // HTTP/2
+ if (m_ConnectionSettings.AssumeHttp2)
+ {
+ curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE);
+ }
+
+ // Verbose/debug
+ if (m_ConnectionSettings.Verbose)
+ {
+ curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L);
+ curl_easy_setopt(Handle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback);
+ curl_easy_setopt(Handle, CURLOPT_DEBUGDATA, &m_Log);
+ }
+
+ // SSL options
+ if (m_ConnectionSettings.InsecureSsl)
+ {
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L);
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L);
+ }
+ if (!m_ConnectionSettings.CaBundlePath.empty())
+ {
+ curl_easy_setopt(Handle, CURLOPT_CAINFO, m_ConnectionSettings.CaBundlePath.c_str());
+ }
+
+ // Disable signal handling for thread safety
+ curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L);
+
+ if (m_ConnectionSettings.ForbidReuseConnection)
+ {
+ curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L);
+ }
+
+ // Note: Headers are NOT set here. Each method builds its own header list
+ // (potentially adding method-specific headers like Content-Type) and passes
+ // ownership to the Session via SetHeaders().
+
+ return Session(this, Handle);
+}
+
+void
+CurlHttpClient::ReleaseSession(CURL* Handle)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::ReleaseSession");
+ m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(Handle); });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// TransactPackage is a two-phase protocol (offer + send) with server-side state
+// between phases, so retrying individual phases would be incorrect.
+CurlHttpClient::Response
+CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::TransactPackage");
+
+ // First, list of offered chunks for filtering on the server end
+
+ std::vector<IoHash> AttachmentsToSend;
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+
+ const uint32_t RequestId = ++CurlHttpClientRequestIdCounter;
+ auto RequestIdString = fmt::to_string(RequestId);
+
+ if (!Attachments.empty())
+ {
+ CbObjectWriter Writer;
+ Writer.BeginArray("offer");
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ Writer.AddHash(Attachment.GetHash());
+ }
+
+ Writer.EndArray();
+
+ BinaryWriter MemWriter;
+ Writer.Save(MemWriter);
+
+ std::vector<std::pair<std::string, std::string>> OfferExtraHeaders;
+ OfferExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackageOffer));
+ OfferExtraHeaders.emplace_back("UE-Request", RequestIdString);
+
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders));
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(MemWriter.Data()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(MemWriter.Size()));
+
+ CurlResult Result = Sess.PerformWithResponseCallbacks();
+
+ if (Result.ErrorCode == CURLE_OK && IsHttpSuccessCode(Result.StatusCode))
+ {
+ IoBuffer ResponseBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size());
+ CbValidateError ValidationError = CbValidateError::None;
+ if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError);
+ ValidationError == CbValidateError::None)
+ {
+ for (CbFieldView& Entry : ResponseObject["need"])
+ {
+ ZEN_ASSERT(Entry.IsHash());
+ AttachmentsToSend.push_back(Entry.AsHash());
+ }
+ }
+ }
+ }
+
+ // Prepare package for send
+
+ CbPackage SendPackage;
+ SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash());
+
+ for (const IoHash& AttachmentCid : AttachmentsToSend)
+ {
+ const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid);
+
+ if (Attachment)
+ {
+ SendPackage.AddAttachment(*Attachment);
+ }
+ }
+
+ // Transmit package payload
+
+ CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage);
+ SharedBuffer FlatMessage = Message.Flatten();
+
+ std::vector<std::pair<std::string, std::string>> PkgExtraHeaders;
+ PkgExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackage));
+ PkgExtraHeaders.emplace_back("UE-Request", RequestIdString);
+
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders));
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(FlatMessage.GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(FlatMessage.GetSize()));
+
+ CurlResult Result = Sess.PerformWithResponseCallbacks();
+
+ return CommonResponse(m_SessionId, std::move(Result), {}, {});
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Standard HTTP verbs
+//
+
+CurlHttpClient::Response
+CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Put");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}));
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Put");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}};
+ Session Sess = AllocSession(Url, Parameters);
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken()));
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL);
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Get");
+ return CommonResponse(m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, Parameters);
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(Sess.Get(), CURLOPT_HTTPGET, 1L);
+ return Sess.PerformWithResponseCallbacks();
+ },
+ [this](CurlResult& Result) {
+ std::unique_ptr<detail::TempPayloadFile> NoTempFile;
+ return ValidatePayload(Result, NoTempFile);
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Head");
+
+ return CommonResponse(m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(Sess.Get(), CURLOPT_NOBODY, 1L);
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Delete");
+
+ return CommonResponse(m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(Sess.Get(), CURLOPT_CUSTOMREQUEST, "DELETE");
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostNoPayload");
+
+ return CommonResponse(m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, Parameters);
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(Sess.Get(), CURLOPT_POST, 1L);
+ curl_easy_setopt(Sess.Get(), CURLOPT_POSTFIELDSIZE, 0L);
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader);
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostWithPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}));
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+
+ FileReadCallbackData ReadData{.Buffer = &Buffer,
+ .TotalSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url,
+ CbObject Payload,
+ const KeyValueMap& AdditionalHeader,
+ const std::filesystem::path& TempFolderPath)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostObjectPayload");
+
+ std::string PayloadString;
+ std::unique_ptr<detail::TempPayloadFile> PayloadFile;
+
+ CurlResult Result = DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ PayloadString.clear();
+ PayloadFile.reset();
+
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)}));
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetBuffer().GetSize()));
+
+ struct PostHeaderCallbackData
+ {
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::string* PayloadString = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ uint64_t MaxInMemorySize = 0;
+ LoggerRef Log;
+ };
+
+ PostHeaderCallbackData PostHdrData;
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ PostHdrData.Headers = &ResponseHeaders;
+ PostHdrData.PayloadFile = &PayloadFile;
+ PostHdrData.PayloadString = &PayloadString;
+ PostHdrData.TempFolderPath = &TempFolderPath;
+ PostHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize;
+ PostHdrData.Log = m_Log;
+
+ auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<PostHeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)))
+ {
+ auto& [Key, Value] = *Header;
+
+ if (StrCaseCompare(Key, "Content-Length") == 0)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Value);
+ if (ContentLength.has_value())
+ {
+ if (!Data->TempFolderPath->empty() && ContentLength.value() > Data->MaxInMemorySize)
+ {
+ *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ Data->PayloadFile->reset();
+ }
+ }
+ else
+ {
+ Data->PayloadString->reserve(ContentLength.value());
+ }
+ }
+ }
+
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb));
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &PostHdrData);
+
+ struct PostWriteCallbackData
+ {
+ std::string* PayloadString = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ LoggerRef Log;
+ };
+
+ PostWriteCallbackData PostWriteData;
+ PostWriteData.PayloadString = &PayloadString;
+ PostWriteData.PayloadFile = &PayloadFile;
+ PostWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr;
+ PostWriteData.TempFolderPath = &TempFolderPath;
+ PostWriteData.Log = m_Log;
+
+ auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<PostWriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0;
+ }
+
+ if (*Data->PayloadFile)
+ {
+ std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes));
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ return 0;
+ }
+ }
+ else
+ {
+ Data->PayloadString->append(Ptr, TotalBytes);
+ }
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb));
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &PostWriteData);
+
+ CurlResult Res = Sess.Perform();
+ Res.Headers = std::move(ResponseHeaders);
+
+ if (!PayloadString.empty())
+ {
+ Res.Body = std::move(PayloadString);
+ }
+
+ return Res;
+ },
+ PayloadFile);
+
+ return CommonResponse(m_SessionId, std::move(Result), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader);
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Post");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}));
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+
+ StreamReadCallbackData ReadData{.Reader = &Reader,
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}));
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+
+ FileReadCallbackData ReadData{.Buffer = &Buffer,
+ .TotalSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }
+
+ ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Upload(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}));
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+
+ StreamReadCallbackData ReadData{.Reader = &Reader,
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Download");
+
+ std::string PayloadString;
+ std::unique_ptr<detail::TempPayloadFile> PayloadFile;
+
+ HttpContentType ContentType = HttpContentType::kUnknownContentType;
+ detail::MultipartBoundaryParser BoundaryParser;
+ bool IsMultiRangeResponse = false;
+
+ CurlResult Result = DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(H, CURLOPT_HTTPGET, 1L);
+
+ // Reset state from any previous attempt
+ PayloadString.clear();
+ PayloadFile.reset();
+ BoundaryParser.Boundaries.clear();
+ ContentType = HttpContentType::kUnknownContentType;
+ IsMultiRangeResponse = false;
+
+ // Track requested content length from Range header (sum all ranges)
+ uint64_t RequestedContentLength = (uint64_t)-1;
+ if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
+ {
+ if (RangeIt->second.starts_with("bytes"))
+ {
+ std::string_view RangeValue(RangeIt->second);
+ size_t RangeStartPos = RangeValue.find('=', 5);
+ if (RangeStartPos != std::string_view::npos)
+ {
+ RangeStartPos++;
+ while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ')
+ {
+ RangeStartPos++;
+ }
+ RequestedContentLength = 0;
+
+ while (RangeStartPos < RangeValue.length())
+ {
+ size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos);
+ if (RangeEnd == std::string_view::npos)
+ {
+ RangeEnd = RangeValue.length();
+ }
+
+ std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos);
+ size_t RangeSplitPos = RangeString.find('-');
+ if (RangeSplitPos != std::string_view::npos)
+ {
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1;
+ }
+ }
+ RangeStartPos = RangeEnd;
+ while (RangeStartPos != RangeValue.length() &&
+ (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' '))
+ {
+ RangeStartPos++;
+ }
+ }
+ }
+ }
+ }
+
+ // Header callback that detects Content-Length and switches to file-backed storage when needed
+ struct DownloadHeaderCallbackData
+ {
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::string* PayloadString = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ uint64_t MaxInMemorySize = 0;
+ LoggerRef Log;
+ detail::MultipartBoundaryParser* BoundaryParser = nullptr;
+ bool* IsMultiRange = nullptr;
+ HttpContentType* ContentTypeOut = nullptr;
+ };
+
+ DownloadHeaderCallbackData DlHdrData;
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ DlHdrData.Headers = &ResponseHeaders;
+ DlHdrData.PayloadFile = &PayloadFile;
+ DlHdrData.PayloadString = &PayloadString;
+ DlHdrData.TempFolderPath = &TempFolderPath;
+ DlHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize;
+ DlHdrData.Log = m_Log;
+ DlHdrData.BoundaryParser = &BoundaryParser;
+ DlHdrData.IsMultiRange = &IsMultiRangeResponse;
+ DlHdrData.ContentTypeOut = &ContentType;
+
+ auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<DownloadHeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)))
+ {
+ auto& [KeyView, Value] = *Header;
+ const std::string Key(KeyView);
+
+ if (StrCaseCompare(Key, "Content-Length") == 0)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Value);
+ if (ContentLength.has_value())
+ {
+ if (ContentLength.value() > Data->MaxInMemorySize)
+ {
+ *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ Data->PayloadFile->reset();
+ }
+ }
+ else
+ {
+ Data->PayloadString->reserve(ContentLength.value());
+ }
+ }
+ }
+ else if (StrCaseCompare(Key, "Content-Type") == 0)
+ {
+ *Data->IsMultiRange = Data->BoundaryParser->Init(Value);
+ if (!*Data->IsMultiRange)
+ {
+ *Data->ContentTypeOut = ParseContentType(Value);
+ }
+ }
+ else if (StrCaseCompare(Key, "Content-Range") == 0)
+ {
+ if (!*Data->IsMultiRange)
+ {
+ std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Value);
+ if (Range.second != 0)
+ {
+ Data->BoundaryParser->Boundaries.push_back(
+ HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0,
+ .RangeOffset = Range.first,
+ .RangeLength = Range.second,
+ .ContentType = *Data->ContentTypeOut});
+ }
+ }
+ }
+
+ Data->Headers->emplace_back(Key, std::string(Value));
+ }
+
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb));
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &DlHdrData);
+
+ // Write callback that directs data to file or string
+ struct DownloadWriteCallbackData
+ {
+ std::string* PayloadString = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ LoggerRef Log;
+ detail::MultipartBoundaryParser* BoundaryParser = nullptr;
+ bool* IsMultiRange = nullptr;
+ };
+
+ DownloadWriteCallbackData DlWriteData;
+ DlWriteData.PayloadString = &PayloadString;
+ DlWriteData.PayloadFile = &PayloadFile;
+ DlWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr;
+ DlWriteData.TempFolderPath = &TempFolderPath;
+ DlWriteData.Log = m_Log;
+ DlWriteData.BoundaryParser = &BoundaryParser;
+ DlWriteData.IsMultiRange = &IsMultiRangeResponse;
+
+ auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<DownloadWriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0;
+ }
+
+ if (*Data->IsMultiRange)
+ {
+ Data->BoundaryParser->ParseInput(std::string_view(Ptr, TotalBytes));
+ }
+
+ if (*Data->PayloadFile)
+ {
+ std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes));
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ return 0;
+ }
+ }
+ else
+ {
+ Data->PayloadString->append(Ptr, TotalBytes);
+ }
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb));
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &DlWriteData);
+
+ CurlResult Res = Sess.Perform();
+ Res.Headers = std::move(ResponseHeaders);
+
+ // Handle resume logic
+ if (m_ConnectionSettings.AllowResume)
+ {
+ auto SupportsRanges = [](const CurlResult& R) -> bool {
+ for (const auto& [K, V] : R.Headers)
+ {
+ if (StrCaseCompare(K, "Content-Range") == 0)
+ {
+ return true;
+ }
+ if (StrCaseCompare(K, "Accept-Ranges") == 0)
+ {
+ return V == "bytes"sv;
+ }
+ }
+ return false;
+ };
+
+ auto ShouldResumeCheck = [&SupportsRanges, &IsMultiRangeResponse](const CurlResult& R) -> bool {
+ if (IsMultiRangeResponse)
+ {
+ return false;
+ }
+ if (ShouldRetry(R))
+ {
+ return SupportsRanges(R);
+ }
+ return false;
+ };
+
+ if (ShouldResumeCheck(Res))
+ {
+ // Find Content-Length
+ std::string ContentLengthValue;
+ for (const auto& [K, V] : Res.Headers)
+ {
+ if (StrCaseCompare(K, "Content-Length") == 0)
+ {
+ ContentLengthValue = V;
+ break;
+ }
+ }
+
+ if (!ContentLengthValue.empty())
+ {
+ uint64_t ContentLength = RequestedContentLength;
+ if (ContentLength == uint64_t(-1))
+ {
+ if (auto ParsedContentLength = ParseInt<int64_t>(ContentLengthValue); ParsedContentLength.has_value())
+ {
+ ContentLength = ParsedContentLength.value();
+ }
+ }
+
+ KeyValueMap HeadersWithRange(AdditionalHeader);
+ uint8_t ResumeAttempt = 0;
+ do
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+
+ std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
+ if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
+ {
+ if (RangeIt->second == Range)
+ {
+ break; // No progress, abort
+ }
+ }
+ HeadersWithRange.Entries.insert_or_assign("Range", Range);
+
+ Session ResumeSess = AllocSession(Url, {});
+ CURL* ResumeH = ResumeSess.Get();
+
+ ResumeSess.SetHeaders(BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(ResumeH, CURLOPT_HTTPGET, 1L);
+
+ std::vector<std::pair<std::string, std::string>> ResumeHeaders;
+
+ struct ResumeHeaderCbData
+ {
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::string* PayloadString = nullptr;
+ };
+
+ ResumeHeaderCbData ResumeHdrData;
+ ResumeHdrData.Headers = &ResumeHeaders;
+ ResumeHdrData.PayloadFile = &PayloadFile;
+ ResumeHdrData.PayloadString = &PayloadString;
+
+ auto ResumeHeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<ResumeHeaderCbData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes));
+ if (!Header)
+ {
+ return TotalBytes;
+ }
+ auto& [Key, Value] = *Header;
+
+ if (StrCaseCompare(Key, "Content-Range") == 0)
+ {
+ if (Value.starts_with("bytes "sv))
+ {
+ size_t RangeStartEnd = Value.find('-', 6);
+ if (RangeStartEnd != std::string_view::npos)
+ {
+ const std::optional<uint64_t> Start = ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6));
+ if (Start)
+ {
+ uint64_t DownloadedSize =
+ *Data->PayloadFile ? (*Data->PayloadFile)->GetSize() : Data->PayloadString->length();
+ if (Start.value() == DownloadedSize)
+ {
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ return TotalBytes;
+ }
+ else if (Start.value() > DownloadedSize)
+ {
+ return 0;
+ }
+ if (*Data->PayloadFile)
+ {
+ (*Data->PayloadFile)->ResetWritePos(Start.value());
+ }
+ else
+ {
+ *Data->PayloadString = Data->PayloadString->substr(0, Start.value());
+ }
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ return TotalBytes;
+ }
+ }
+ }
+ return 0;
+ }
+
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(ResumeH,
+ CURLOPT_HEADERFUNCTION,
+ static_cast<size_t (*)(char*, size_t, size_t, void*)>(ResumeHeaderCb));
+ curl_easy_setopt(ResumeH, CURLOPT_HEADERDATA, &ResumeHdrData);
+ curl_easy_setopt(ResumeH,
+ CURLOPT_WRITEFUNCTION,
+ static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb));
+ curl_easy_setopt(ResumeH, CURLOPT_WRITEDATA, &DlWriteData);
+
+ Res = ResumeSess.Perform();
+ Res.Headers = std::move(ResumeHeaders);
+
+ ResumeAttempt++;
+ } while (ResumeAttempt < m_ConnectionSettings.RetryCount && ShouldResumeCheck(Res));
+ }
+ }
+ }
+
+ if (!PayloadString.empty())
+ {
+ Res.Body = std::move(PayloadString);
+ }
+
+ return Res;
+ },
+ PayloadFile);
+
+ return CommonResponse(m_SessionId,
+ std::move(Result),
+ PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{},
+ std::move(BoundaryParser.Boundaries));
+}
+
+} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h
new file mode 100644
index 000000000..b7fa52e6c
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcurl.h
@@ -0,0 +1,137 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "httpclientcommon.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/httpclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <curl/curl.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+class CurlHttpClient : public HttpClientBase
+{
+public:
+ CurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction);
+ ~CurlHttpClient();
+
+ // HttpClientBase
+
+ [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Get(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ CbObject Payload,
+ const KeyValueMap& AdditionalHeader = {},
+ const std::filesystem::path& TempFolderPath = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Upload(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+ [[nodiscard]] virtual Response Download(std::string_view Url,
+ const std::filesystem::path& TempFolderPath,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+ [[nodiscard]] virtual Response TransactPackage(std::string_view Url,
+ CbPackage Package,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+private:
+ struct CurlResult
+ {
+ long StatusCode = 0;
+ std::string Body;
+ std::vector<std::pair<std::string, std::string>> Headers;
+ double ElapsedSeconds = 0;
+ int64_t UploadedBytes = 0;
+ int64_t DownloadedBytes = 0;
+ CURLcode ErrorCode = CURLE_OK;
+ std::string ErrorMessage;
+ };
+
+ struct Session
+ {
+ Session(CurlHttpClient* InOuter, CURL* InHandle) : Outer(InOuter), Handle(InHandle) {}
+ ~Session();
+
+ CURL* Get() const { return Handle; }
+
+ // Takes ownership of the curl_slist and sets it on the handle.
+ // The list is freed automatically when the Session is destroyed.
+ void SetHeaders(curl_slist* Headers);
+
+ // Low-level perform: executes the request and collects status/timing.
+ CurlResult Perform();
+
+ // Sets up standard write+header callbacks, performs the request, and
+ // moves the collected body and headers into the returned CurlResult.
+ CurlResult PerformWithResponseCallbacks();
+
+ LoggerRef Log() { return Outer->Log(); }
+
+ private:
+ CurlHttpClient* Outer;
+ CURL* Handle;
+ curl_slist* HeaderList = nullptr;
+
+ Session(Session&&) = delete;
+ Session& operator=(Session&&) = delete;
+ };
+
+ Session AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters);
+
+ RwLock m_SessionLock;
+ std::vector<CURL*> m_Sessions;
+
+ void ReleaseSession(CURL* Handle);
+
+ CurlResult DoWithRetry(std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
+ CurlResult DoWithRetry(
+ std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::function<bool(CurlResult&)>&& Validate = [](CurlResult&) { return true; });
+
+ bool ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
+
+ static bool ShouldRetry(const CurlResult& Result);
+
+ bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const;
+
+ HttpClient::Response CommonResponse(std::string_view SessionId,
+ CurlResult&& Result,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {});
+
+ HttpClient::Response ResponseWithPayload(std::string_view SessionId,
+ CurlResult&& Result,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions);
+};
+
+} // namespace zen
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
new file mode 100644
index 000000000..fbae9f5fe
--- /dev/null
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -0,0 +1,641 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpwsclient.h>
+
+#include "../servers/wsframecodec.h"
+
+#include <zencore/base64.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <deque>
+#include <random>
+#include <thread>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpWsClient::Impl
+{
+ Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_OwnedIoContext(std::make_unique<asio::io_context>())
+ , m_IoContext(*m_OwnedIoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_IoContext(IoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ ~Impl()
+ {
+ // Release work guard so io_context::run() can return
+ m_WorkGuard.reset();
+
+ // Close the socket to cancel pending async ops
+ CloseSocket();
+
+ if (m_IoThread.joinable())
+ {
+ m_IoThread.join();
+ }
+ }
+
+ void CloseSocket()
+ {
+ asio::error_code Ec;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixSocket)
+ {
+ m_UnixSocket->close(Ec);
+ return;
+ }
+#endif
+ if (m_TcpSocket)
+ {
+ m_TcpSocket->close(Ec);
+ }
+ }
+
+ template<typename Fn>
+ void WithSocket(Fn&& Func)
+ {
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixSocket)
+ {
+ Func(*m_UnixSocket);
+ return;
+ }
+#endif
+ Func(*m_TcpSocket);
+ }
+
+ void ParseUrl(std::string_view Url)
+ {
+ // Expected format: ws://host:port/path
+ if (Url.substr(0, 5) == "ws://")
+ {
+ Url.remove_prefix(5);
+ }
+
+ auto SlashPos = Url.find('/');
+ std::string_view HostPort;
+ if (SlashPos != std::string_view::npos)
+ {
+ HostPort = Url.substr(0, SlashPos);
+ m_Path = std::string(Url.substr(SlashPos));
+ }
+ else
+ {
+ HostPort = Url;
+ m_Path = "/";
+ }
+
+ auto ColonPos = HostPort.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ m_Host = std::string(HostPort.substr(0, ColonPos));
+ m_Port = std::string(HostPort.substr(ColonPos + 1));
+ }
+ else
+ {
+ m_Host = std::string(HostPort);
+ m_Port = "80";
+ }
+ }
+
+ void Connect()
+ {
+ if (m_OwnedIoContext)
+ {
+ m_WorkGuard.emplace(m_IoContext.get_executor());
+ m_IoThread = std::thread([this] { m_IoContext.run(); });
+ }
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (!m_Settings.UnixSocketPath.empty())
+ {
+ asio::post(m_IoContext, [this] { DoConnectUnix(); });
+ return;
+ }
+#endif
+
+ asio::post(m_IoContext, [this] { DoResolve(); });
+ }
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ void DoConnectUnix()
+ {
+ m_UnixSocket = std::make_unique<asio::local::stream_protocol::socket>(m_IoContext);
+
+ // Start connect timeout timer
+ m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
+ m_Timer->async_wait([this](const asio::error_code& Ec) {
+ if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect timeout for {}", m_Settings.UnixSocketPath);
+ CloseSocket();
+ }
+ });
+
+ asio::local::stream_protocol::endpoint Endpoint(PathToUtf8(m_Settings.UnixSocketPath));
+ m_UnixSocket->async_connect(Endpoint, [this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect failed for {}: {}", m_Settings.UnixSocketPath, Ec.message());
+ m_Handler.OnWsClose(1006, "connect failed");
+ return;
+ }
+
+ DoHandshake();
+ });
+ }
+#endif
+
+ void DoResolve()
+ {
+ m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext);
+
+ m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) {
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "resolve failed");
+ return;
+ }
+
+ DoConnect(Results);
+ });
+ }
+
+ void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints)
+ {
+ m_TcpSocket = std::make_unique<asio::ip::tcp::socket>(m_IoContext);
+
+ // Start connect timeout timer
+ m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
+ m_Timer->async_wait([this](const asio::error_code& Ec) {
+ if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port);
+ CloseSocket();
+ }
+ });
+
+ asio::async_connect(*m_TcpSocket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "connect failed");
+ return;
+ }
+
+ DoHandshake();
+ });
+ }
+
+ void DoHandshake()
+ {
+ // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded)
+ uint8_t KeyBytes[16];
+ {
+ static thread_local std::mt19937 s_Rng(std::random_device{}());
+ for (int i = 0; i < 4; ++i)
+ {
+ uint32_t Val = s_Rng();
+ std::memcpy(KeyBytes + i * 4, &Val, 4);
+ }
+ }
+
+ char KeyBase64[Base64::GetEncodedDataSize(16) + 1];
+ uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64);
+ KeyBase64[KeyLen] = '\0';
+ m_WebSocketKey = std::string(KeyBase64, KeyLen);
+
+ // Build the HTTP upgrade request
+ ExtendableStringBuilder<512> Request;
+ Request << "GET " << m_Path << " HTTP/1.1\r\n"
+ << "Host: " << m_Host << ":" << m_Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n"
+ << "Sec-WebSocket-Version: 13\r\n";
+
+ // Add Authorization header if access token provider is set
+ if (m_Settings.AccessTokenProvider)
+ {
+ HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)();
+ if (Token.IsValid())
+ {
+ Request << "Authorization: Bearer " << Token.Value << "\r\n";
+ }
+ }
+
+ Request << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+
+ m_HandshakeBuffer = std::make_shared<std::string>(ReqStr);
+
+ WithSocket([this](auto& Socket) {
+ asio::async_write(Socket,
+ asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()),
+ [this](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake write failed");
+ return;
+ }
+
+ DoReadHandshakeResponse();
+ });
+ });
+ }
+
+ void DoReadHandshakeResponse()
+ {
+ WithSocket([this](auto& Socket) {
+ asio::async_read_until(Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) {
+ m_Timer->cancel();
+
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake read failed");
+ return;
+ }
+
+ // Parse the response
+ const auto& Data = m_ReadBuffer.data();
+ std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
+
+ // Consume the headers from the read buffer (any extra data stays for frame parsing)
+ auto HeaderEnd = Response.find("\r\n\r\n");
+ if (HeaderEnd != std::string::npos)
+ {
+ m_ReadBuffer.consume(HeaderEnd + 4);
+ }
+
+ // Validate 101 response
+ if (Response.find("101") == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
+ m_Handler.OnWsClose(1006, "handshake rejected");
+ return;
+ }
+
+ // Validate Sec-WebSocket-Accept
+ std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
+ if (Response.find(ExpectedAccept) == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
+ m_Handler.OnWsClose(1006, "invalid accept key");
+ return;
+ }
+
+ m_IsOpen.store(true);
+ m_Handler.OnWsOpen();
+ EnqueueRead();
+ });
+ });
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Read loop
+ //
+
+ void EnqueueRead()
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ WithSocket([this](auto& Socket) {
+ asio::async_read(Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) {
+ OnDataReceived(Ec);
+ });
+ });
+ }
+
+ void OnDataReceived(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message());
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "connection lost");
+ }
+ return;
+ }
+
+ ProcessReceivedData();
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ EnqueueRead();
+ }
+ }
+
+ void ProcessReceivedData()
+ {
+ while (m_ReadBuffer.size() > 0)
+ {
+ const auto& InputBuffer = m_ReadBuffer.data();
+ const auto* RawData = static_cast<const uint8_t*>(InputBuffer.data());
+ const auto Size = InputBuffer.size();
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size);
+ if (!Frame.IsValid)
+ {
+ break;
+ }
+
+ m_ReadBuffer.consume(Frame.BytesConsumed);
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWsMessage(Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with masked pong
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason =
+ std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo masked close frame if we haven't sent one yet
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWsClose(Code, Reason);
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Write queue
+ //
+
+ void EnqueueWrite(std::vector<uint8_t> Frame)
+ {
+ bool ShouldFlush = false;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_WriteQueue.push_back(std::move(Frame));
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ });
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+ }
+
+ void FlushWriteQueue()
+ {
+ std::vector<uint8_t> Frame;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+ Frame = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ });
+
+ if (Frame.empty())
+ {
+ return;
+ }
+
+ auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
+
+ WithSocket([this, OwnedFrame](auto& Socket) {
+ asio::async_write(Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); });
+ });
+ }
+
+ void OnWriteComplete(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message());
+ }
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_IsWriting = false;
+ m_WriteQueue.clear();
+ });
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "write error");
+ }
+ return;
+ }
+
+ FlushWriteQueue();
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Public operations
+ //
+
+ void SendText(std::string_view Text)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void SendBinary(std::span<const uint8_t> Data)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void DoClose(uint16_t Code, std::string_view Reason)
+ {
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ IWsClientHandler& m_Handler;
+ HttpWsClientSettings m_Settings;
+ LoggerRef m_Log;
+
+ std::string m_Host;
+ std::string m_Port;
+ std::string m_Path;
+
+ // io_context: owned (standalone) or external (shared)
+ std::unique_ptr<asio::io_context> m_OwnedIoContext;
+ asio::io_context& m_IoContext;
+ std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard;
+ std::thread m_IoThread;
+
+ // Connection state
+ std::unique_ptr<asio::ip::tcp::resolver> m_Resolver;
+ std::unique_ptr<asio::ip::tcp::socket> m_TcpSocket;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ std::unique_ptr<asio::local::stream_protocol::socket> m_UnixSocket;
+#endif
+ std::unique_ptr<asio::steady_timer> m_Timer;
+ asio::streambuf m_ReadBuffer;
+ std::string m_WebSocketKey;
+ std::shared_ptr<std::string> m_HandshakeBuffer;
+
+ // Write queue
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ bool m_IsWriting = false;
+
+ std::atomic<bool> m_IsOpen{false};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, Settings))
+{
+}
+
+HttpWsClient::HttpWsClient(std::string_view Url,
+ IWsClientHandler& Handler,
+ asio::io_context& IoContext,
+ const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, IoContext, Settings))
+{
+}
+
+HttpWsClient::~HttpWsClient() = default;
+
+void
+HttpWsClient::Connect()
+{
+ m_Impl->Connect();
+}
+
+void
+HttpWsClient::SendText(std::string_view Text)
+{
+ m_Impl->SendText(Text);
+}
+
+void
+HttpWsClient::SendBinary(std::span<const uint8_t> Data)
+{
+ m_Impl->SendBinary(Data);
+}
+
+void
+HttpWsClient::Close(uint16_t Code, std::string_view Reason)
+{
+ m_Impl->DoClose(Code, Reason);
+}
+
+bool
+HttpWsClient::IsOpen() const
+{
+ return m_Impl->m_IsOpen.load(std::memory_order_relaxed);
+}
+
+} // namespace zen
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index d3b59df2b..9f49802a0 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -21,6 +21,8 @@
#include "clients/httpclientcommon.h"
+#include <numeric>
+
#if ZEN_WITH_TESTS
# include <zencore/scopeguard.h>
# include <zencore/testing.h>
@@ -34,9 +36,43 @@
namespace zen {
+#if ZEN_WITH_CPR
extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri,
const HttpClientSettings& ConnectionSettings,
std::function<bool()>&& CheckIfAbortFunction);
+#endif
+
+extern HttpClientBase* CreateCurlHttpClient(std::string_view BaseUri,
+ const HttpClientSettings& ConnectionSettings,
+ std::function<bool()>&& CheckIfAbortFunction);
+
+static HttpClientBackend g_DefaultHttpClientBackend = HttpClientBackend::kCurl;
+
+void
+SetDefaultHttpClientBackend(HttpClientBackend Backend)
+{
+ g_DefaultHttpClientBackend = Backend;
+}
+
+void
+SetDefaultHttpClientBackend(std::string_view Backend)
+{
+#if ZEN_WITH_CPR
+ if (Backend == "cpr")
+ {
+ g_DefaultHttpClientBackend = HttpClientBackend::kCpr;
+ }
+ else
+#endif
+ if (Backend == "curl")
+ {
+ g_DefaultHttpClientBackend = HttpClientBackend::kCurl;
+ }
+ else
+ {
+ g_DefaultHttpClientBackend = HttpClientBackend::kDefault;
+ }
+}
using namespace std::literals;
@@ -102,6 +138,109 @@ HttpClientBase::GetAccessToken()
//////////////////////////////////////////////////////////////////////////
+HttpClientError::ResponseClass
+HttpClientError::GetResponseClass() const
+{
+ if (m_Error != HttpClientErrorCode::kOK)
+ {
+ switch (m_Error)
+ {
+ case HttpClientErrorCode::kConnectionFailure:
+ return ResponseClass::kHttpCantConnectError;
+ case HttpClientErrorCode::kHostResolutionFailure:
+ case HttpClientErrorCode::kProxyResolutionFailure:
+ return ResponseClass::kHttpNoHost;
+ case HttpClientErrorCode::kInternalError:
+ case HttpClientErrorCode::kNetworkReceiveError:
+ case HttpClientErrorCode::kNetworkSendFailure:
+ case HttpClientErrorCode::kOperationTimedOut:
+ return ResponseClass::kHttpTimeout;
+ case HttpClientErrorCode::kSSLConnectError:
+ case HttpClientErrorCode::kSSLCertificateError:
+ case HttpClientErrorCode::kSSLCACertError:
+ case HttpClientErrorCode::kGenericSSLError:
+ return ResponseClass::kHttpSLLError;
+ default:
+ return ResponseClass::kHttpOtherClientError;
+ }
+ }
+ else if (IsHttpSuccessCode(m_ResponseCode))
+ {
+ return ResponseClass::kSuccess;
+ }
+ else
+ {
+ switch (m_ResponseCode)
+ {
+ case HttpResponseCode::Unauthorized:
+ return ResponseClass::kHttpUnauthorized;
+ case HttpResponseCode::NotFound:
+ return ResponseClass::kHttpNotFound;
+ case HttpResponseCode::Forbidden:
+ return ResponseClass::kHttpForbidden;
+ case HttpResponseCode::Conflict:
+ return ResponseClass::kHttpConflict;
+ case HttpResponseCode::InternalServerError:
+ return ResponseClass::kHttpInternalServerError;
+ case HttpResponseCode::ServiceUnavailable:
+ return ResponseClass::kHttpServiceUnavailable;
+ case HttpResponseCode::BadGateway:
+ return ResponseClass::kHttpBadGateway;
+ case HttpResponseCode::GatewayTimeout:
+ return ResponseClass::kHttpGatewayTimeout;
+ default:
+ if (m_ResponseCode >= HttpResponseCode::InternalServerError)
+ {
+ return ResponseClass::kHttpOtherServerError;
+ }
+ else
+ {
+ return ResponseClass::kHttpOtherClientError;
+ }
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+std::vector<std::pair<uint64_t, uint64_t>>
+HttpClient::Response::GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const
+{
+ if (Ranges.empty())
+ {
+ return {};
+ }
+
+ std::vector<std::pair<uint64_t, uint64_t>> Result;
+ Result.reserve(OffsetAndLengthPairs.size());
+
+ auto BoundaryIt = Ranges.begin();
+ auto OffsetAndLengthPairIt = OffsetAndLengthPairs.begin();
+ while (OffsetAndLengthPairIt != OffsetAndLengthPairs.end())
+ {
+ uint64_t Offset = OffsetAndLengthPairIt->first;
+ uint64_t Length = OffsetAndLengthPairIt->second;
+ while (Offset >= BoundaryIt->RangeOffset + BoundaryIt->RangeLength)
+ {
+ BoundaryIt++;
+ if (BoundaryIt == Ranges.end())
+ {
+ throw std::runtime_error("HttpClient::Response can not fulfill requested range");
+ }
+ }
+ if (Offset + Length > BoundaryIt->RangeOffset + BoundaryIt->RangeLength || Offset < BoundaryIt->RangeOffset)
+ {
+ throw std::runtime_error("HttpClient::Response can not fulfill requested range");
+ }
+ uint64_t OffsetIntoRange = Offset - BoundaryIt->RangeOffset;
+ uint64_t RangePayloadOffset = BoundaryIt->OffsetInPayload + OffsetIntoRange;
+ Result.emplace_back(std::make_pair(RangePayloadOffset, Length));
+
+ OffsetAndLengthPairIt++;
+ }
+ return Result;
+}
+
CbObject
HttpClient::Response::AsObject() const
{
@@ -182,7 +321,11 @@ HttpClient::Response::ErrorMessage(std::string_view Prefix) const
{
if (Error.has_value())
{
- return fmt::format("{}{}HTTP error ({}) '{}'", Prefix, Prefix.empty() ? ""sv : ": "sv, Error->ErrorCode, Error->ErrorMessage);
+ return fmt::format("{}{}HTTP error ({}) '{}'",
+ Prefix,
+ Prefix.empty() ? ""sv : ": "sv,
+ static_cast<int>(Error->ErrorCode),
+ Error->ErrorMessage);
}
else if (StatusCode != HttpResponseCode::ImATeapot && (int)StatusCode)
{
@@ -205,19 +348,36 @@ HttpClient::Response::ThrowError(std::string_view ErrorPrefix)
{
if (!IsSuccess())
{
- throw HttpClientError(ErrorMessage(ErrorPrefix), Error.has_value() ? Error.value().ErrorCode : 0, StatusCode);
+ throw HttpClientError(ErrorMessage(ErrorPrefix),
+ Error.has_value() ? Error.value().ErrorCode : HttpClientErrorCode::kOK,
+ StatusCode);
}
}
//////////////////////////////////////////////////////////////////////////
HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction)
-: m_BaseUri(BaseUri)
+: m_Log(zen::logging::Get(ConnectionSettings.LogCategory))
+, m_BaseUri(BaseUri)
, m_ConnectionSettings(ConnectionSettings)
{
m_SessionId = GetSessionIdString();
- m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+ HttpClientBackend EffectiveBackend =
+ ConnectionSettings.Backend != HttpClientBackend::kDefault ? ConnectionSettings.Backend : g_DefaultHttpClientBackend;
+
+ switch (EffectiveBackend)
+ {
+#if ZEN_WITH_CPR
+ case HttpClientBackend::kCpr:
+ m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+ break;
+#endif
+ case HttpClientBackend::kCurl:
+ default:
+ m_Inner = CreateCurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+ break;
+ }
}
HttpClient::~HttpClient()
@@ -287,9 +447,12 @@ HttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType C
}
HttpClient::Response
-HttpClient::Post(std::string_view Url, CbObject Payload, const HttpClient::KeyValueMap& AdditionalHeader)
+HttpClient::Post(std::string_view Url,
+ CbObject Payload,
+ const HttpClient::KeyValueMap& AdditionalHeader,
+ const std::filesystem::path& TempFolderPath)
{
- return m_Inner->Post(Url, Payload, AdditionalHeader);
+ return m_Inner->Post(Url, Payload, AdditionalHeader, TempFolderPath);
}
HttpClient::Response
@@ -340,10 +503,55 @@ HttpClient::Authenticate()
return m_Inner->Authenticate();
}
+LatencyTestResult
+MeasureLatency(HttpClient& Client, std::string_view Url)
+{
+ std::vector<double> MeasurementTimes;
+ std::string ErrorMessage;
+
+ for (uint32_t AttemptCount = 0; AttemptCount < 20 && MeasurementTimes.size() < 5; AttemptCount++)
+ {
+ HttpClient::Response MeasureResponse = Client.Get(Url);
+ if (MeasureResponse.IsSuccess())
+ {
+ MeasurementTimes.push_back(MeasureResponse.ElapsedSeconds);
+ Sleep(5);
+ }
+ else
+ {
+ ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url));
+
+ // Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable.
+ // Bail out immediately — retrying will just burn the connect timeout each time.
+ if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError())
+ {
+ break;
+ }
+ }
+ }
+
+ if (MeasurementTimes.empty())
+ {
+ return {.Success = false, .FailureReason = ErrorMessage};
+ }
+
+ if (MeasurementTimes.size() > 2)
+ {
+ std::sort(MeasurementTimes.begin(), MeasurementTimes.end());
+ MeasurementTimes.pop_back(); // Remove the worst time
+ }
+
+ double AverageLatency = std::accumulate(MeasurementTimes.begin(), MeasurementTimes.end(), 0.0) / MeasurementTimes.size();
+
+ return {.Success = true, .LatencySeconds = AverageLatency};
+}
+
//////////////////////////////////////////////////////////////////////////
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("http.httpclient");
+
TEST_CASE("responseformat")
{
using namespace std::literals;
@@ -753,6 +961,8 @@ TEST_CASE("httpclient.password")
AsioServer->RequestExit();
}
}
+TEST_SUITE_END();
+
void
httpclient_forcelink()
{
diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp
new file mode 100644
index 000000000..5f3ad2455
--- /dev/null
+++ b/src/zenhttp/httpclient_test.cpp
@@ -0,0 +1,1701 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpclient.h>
+#include <zenhttp/httpserver.h>
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinaryutil.h>
+# include <zencore/compositebuffer.h>
+# include <zencore/filesystem.h>
+# include <zencore/iobuffer.h>
+# include <zencore/logging.h>
+# include <zencore/scopeguard.h>
+# include <zencore/session.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+# include "servers/httpasio.h"
+
+# include <atomic>
+# include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+// Test service
+
+class HttpClientTestService : public HttpService
+{
+public:
+ HttpClientTestService()
+ {
+ m_Router.AddMatcher("statuscode", [](std::string_view Str) -> bool {
+ for (char C : Str)
+ {
+ if (C < '0' || C > '9')
+ {
+ return false;
+ }
+ }
+ return !Str.empty();
+ });
+
+ m_Router.RegisterRoute(
+ "hello",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "echo",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ IoBuffer Body = HttpReq.ReadPayload();
+ HttpContentType CT = HttpReq.RequestContentType();
+ HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body);
+ },
+ HttpVerb::kPost | HttpVerb::kPut);
+
+ m_Router.RegisterRoute(
+ "echo/headers",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view Auth = HttpReq.GetAuthorizationHeader();
+ CbObjectWriter Writer;
+ if (!Auth.empty())
+ {
+ Writer.AddString("Authorization", Auth);
+ }
+ HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save());
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "echo/method",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view Method = ToString(HttpReq.RequestVerb());
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method);
+ },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead);
+
+ m_Router.RegisterRoute(
+ "json",
+ [](HttpRouterRequest& Req) {
+ CbObjectWriter Obj;
+ Obj.AddBool("ok", true);
+ Obj.AddString("message", "test");
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "nocontent",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete);
+
+ m_Router.RegisterRoute(
+ "created",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::Created, HttpContentType::kText, "resource created");
+ },
+ HttpVerb::kPost | HttpVerb::kPut);
+
+ m_Router.RegisterRoute(
+ "content-type/text",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "plain text"); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "content-type/json",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"key\":\"value\"}");
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "content-type/binary",
+ [](HttpRouterRequest& Req) {
+ uint8_t Data[] = {0xDE, 0xAD, 0xBE, 0xEF};
+ IoBuffer Buf(IoBuffer::Clone, Data, sizeof(Data));
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "content-type/cbobject",
+ [](HttpRouterRequest& Req) {
+ CbObjectWriter Obj;
+ Obj.AddString("type", "cbobject");
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "auth/bearer",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view Auth = HttpReq.GetAuthorizationHeader();
+ if (Auth.starts_with("Bearer ") && Auth.size() > 7)
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "authenticated");
+ }
+ else
+ {
+ HttpReq.WriteResponse(HttpResponseCode::Unauthorized, HttpContentType::kText, "unauthorized");
+ }
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "slow",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) {
+ Sleep(2000);
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response");
+ });
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "large",
+ [](HttpRouterRequest& Req) {
+ constexpr size_t Size = 64 * 1024;
+ IoBuffer Buf(Size);
+ uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData());
+ for (size_t i = 0; i < Size; ++i)
+ {
+ Ptr[i] = static_cast<uint8_t>(i & 0xFF);
+ }
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "status/{statuscode}",
+ [](HttpRouterRequest& Req) {
+ std::string_view CodeStr = Req.GetCapture(1);
+ int Code = std::stoi(std::string{CodeStr});
+ const HttpResponseCode ResponseCode = static_cast<HttpResponseCode>(Code);
+ Req.ServerRequest().WriteResponse(ResponseCode);
+ },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead);
+
+ m_Router.RegisterRoute(
+ "attempt-counter",
+ [this](HttpRouterRequest& Req) {
+ uint32_t Count = m_AttemptCounter.fetch_add(1);
+ if (Count < m_FailCount)
+ {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::ServiceUnavailable);
+ }
+ else
+ {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "success after retries");
+ }
+ },
+ HttpVerb::kGet);
+ }
+
+ virtual const char* BaseUri() const override { return "/api/test/"; }
+ virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); }
+
+ void ResetAttemptCounter(uint32_t FailCount)
+ {
+ m_AttemptCounter.store(0);
+ m_FailCount = FailCount;
+ }
+
+private:
+ HttpRequestRouter m_Router;
+ std::atomic<uint32_t> m_AttemptCounter{0};
+ uint32_t m_FailCount = 2;
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Test server fixture
+
+struct TestServerFixture
+{
+ HttpClientTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+ Ref<HttpServer> Server;
+ std::thread ServerThread;
+ int Port = -1;
+
+ TestServerFixture()
+ {
+ Server = CreateHttpAsioServer(AsioConfig{});
+ Port = Server->Initialize(0, TmpDir.Path());
+ ZEN_ASSERT(Port != -1);
+ Server->RegisterService(TestService);
+ ServerThread = std::thread([this]() { Server->Run(false); });
+ }
+
+ ~TestServerFixture()
+ {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ }
+
+ HttpClient MakeClient(HttpClientSettings Settings = {})
+ {
+ return HttpClient(fmt::format("127.0.0.1:{}", Port), Settings, /*CheckIfAbortFunction*/ {});
+ }
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Tests
+
+TEST_SUITE_BEGIN("http.httpclient");
+
+TEST_CASE("httpclient.verbs")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("GET returns 200 with expected body")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "GET");
+ }
+
+ SUBCASE("POST dispatches correctly")
+ {
+ HttpClient::Response Resp = Client.Post("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "POST");
+ }
+
+ SUBCASE("PUT dispatches correctly")
+ {
+ HttpClient::Response Resp = Client.Put("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "PUT");
+ }
+
+ SUBCASE("DELETE dispatches correctly")
+ {
+ HttpClient::Response Resp = Client.Delete("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "DELETE");
+ }
+
+ SUBCASE("HEAD returns 200 with empty body")
+ {
+ HttpClient::Response Resp = Client.Head("/api/test/echo/method");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), ""sv);
+ }
+}
+
+TEST_CASE("httpclient.get")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("simple GET with text response")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK);
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("GET with auth header via echo")
+ {
+ HttpClient::Response Resp =
+ Client.Get("/api/test/echo/headers", std::pair<std::string, std::string>("Authorization", "Bearer test-token-123"));
+ CHECK(Resp.IsSuccess());
+ CbObject Obj = Resp.AsObject();
+ CHECK_EQ(Obj["Authorization"].AsString(), "Bearer test-token-123");
+ }
+
+ SUBCASE("GET returning CbObject")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ CHECK(Resp.IsSuccess());
+ CbObject Obj = Resp.AsObject();
+ CHECK(Obj["ok"].AsBool() == true);
+ CHECK_EQ(Obj["message"].AsString(), "test");
+ }
+
+ SUBCASE("GET large payload")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/large");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+
+ const uint8_t* Data = static_cast<const uint8_t*>(Resp.ResponsePayload.GetData());
+ bool Valid = true;
+ for (size_t i = 0; i < 64 * 1024; ++i)
+ {
+ if (Data[i] != static_cast<uint8_t>(i & 0xFF))
+ {
+ Valid = false;
+ break;
+ }
+ }
+ CHECK(Valid);
+ }
+}
+
+TEST_CASE("httpclient.post")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("POST with IoBuffer payload echo round-trip")
+ {
+ const char* Payload = "test payload data";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "test payload data");
+ }
+
+ SUBCASE("POST with IoBuffer and explicit content type")
+ {
+ const char* Payload = "{\"key\":\"value\"}";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Buf, ZenContentType::kJSON);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "{\"key\":\"value\"}");
+ }
+
+ SUBCASE("POST with CbObject payload round-trip")
+ {
+ CbObjectWriter Writer;
+ Writer.AddBool("enabled", true);
+ Writer.AddString("name", "testobj");
+ CbObject Obj = Writer.Save();
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Obj);
+ CHECK(Resp.IsSuccess());
+ CbObject RoundTripped = Resp.AsObject();
+ CHECK(RoundTripped["enabled"].AsBool() == true);
+ CHECK_EQ(RoundTripped["name"].AsString(), "testobj");
+ }
+
+ SUBCASE("POST with CompositeBuffer payload")
+ {
+ const char* Part1 = "hello ";
+ const char* Part2 = "composite";
+ IoBuffer Buf1(IoBuffer::Clone, Part1, strlen(Part1));
+ IoBuffer Buf2(IoBuffer::Clone, Part2, strlen(Part2));
+
+ SharedBuffer Seg1{Buf1};
+ SharedBuffer Seg2{Buf2};
+ CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)};
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Composite, ZenContentType::kText);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello composite");
+ }
+
+ SUBCASE("POST with custom headers")
+ {
+ HttpClient::Response Resp = Client.Post("/api/test/echo/headers", HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{});
+ CHECK(Resp.IsSuccess());
+ }
+
+ SUBCASE("POST with empty body to nocontent endpoint")
+ {
+ HttpClient::Response Resp = Client.Post("/api/test/nocontent");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent);
+ }
+}
+
+TEST_CASE("httpclient.put")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("PUT with IoBuffer payload echo round-trip")
+ {
+ const char* Payload = "put payload data";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Put("/api/test/echo", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "put payload data");
+ }
+
+ SUBCASE("PUT with parameters only")
+ {
+ HttpClient::Response Resp = Client.Put("/api/test/nocontent");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent);
+ }
+
+ SUBCASE("PUT to created endpoint")
+ {
+ const char* Payload = "new resource";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Put("/api/test/created", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::Created);
+ CHECK_EQ(Resp.AsText(), "resource created");
+ }
+}
+
+TEST_CASE("httpclient.upload")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("Upload IoBuffer")
+ {
+ constexpr size_t Size = 128 * 1024;
+ IoBuffer Blob = CreateSemiRandomBlob(Size);
+
+ HttpClient::Response Resp = Client.Upload("/api/test/echo", Blob);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), Size);
+ }
+
+ SUBCASE("Upload CompositeBuffer")
+ {
+ IoBuffer Buf1 = CreateSemiRandomBlob(32 * 1024);
+ IoBuffer Buf2 = CreateSemiRandomBlob(32 * 1024);
+
+ SharedBuffer Seg1{Buf1};
+ SharedBuffer Seg2{Buf2};
+ CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)};
+
+ HttpClient::Response Resp = Client.Upload("/api/test/echo", Composite, ZenContentType::kBinary);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+ }
+}
+
+TEST_CASE("httpclient.download")
+{
+ TestServerFixture Fixture;
+ ScopedTemporaryDirectory DownloadDir;
+
+ SUBCASE("Download small payload stays in memory")
+ {
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response Resp = Client.Download("/api/test/hello", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("Download with reduced MaximumInMemoryDownloadSize forces file spill")
+ {
+ HttpClientSettings Settings;
+ Settings.MaximumInMemoryDownloadSize = 4;
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Download("/api/test/large", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+ }
+}
+
+TEST_CASE("httpclient.post-streaming")
+{
+ TestServerFixture Fixture;
+ ScopedTemporaryDirectory PostDir;
+
+ SUBCASE("POST CbObject with TempFolderPath stays in memory when response is small")
+ {
+ HttpClient Client = Fixture.MakeClient();
+
+ CbObjectWriter Writer;
+ Writer.AddBool("streaming", false);
+ Writer.AddString("mode", "memory");
+ CbObject Obj = Writer.Save();
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Obj, {}, PostDir.Path());
+ CHECK(Resp.IsSuccess());
+ IoBufferFileReference _;
+ CHECK(!Resp.ResponsePayload.GetFileReference(_));
+ CbObject RoundTripped = Resp.AsObject();
+ CHECK(RoundTripped["streaming"].AsBool() == false);
+ CHECK_EQ(RoundTripped["mode"].AsString(), "memory");
+ }
+
+ SUBCASE("POST CbObject with TempFolderPath streams to file when response exceeds MaximumInMemoryDownloadSize")
+ {
+ HttpClientSettings Settings;
+ Settings.MaximumInMemoryDownloadSize = 4;
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ CbObjectWriter Writer;
+ Writer.AddBool("streaming", true);
+ Writer.AddString("mode", "file");
+ CbObject Obj = Writer.Save();
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Obj, {}, PostDir.Path());
+ CHECK(Resp.IsSuccess());
+ IoBufferFileReference _;
+ CHECK(Resp.ResponsePayload.GetFileReference(_));
+ CbObject RoundTripped = Resp.AsObject();
+ CHECK(RoundTripped["streaming"].AsBool() == true);
+ CHECK_EQ(RoundTripped["mode"].AsString(), "file");
+ }
+}
+
+TEST_CASE("httpclient.status-codes")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("2xx are success")
+ {
+ CHECK(Client.Get("/api/test/status/200").IsSuccess());
+ CHECK(Client.Get("/api/test/status/201").IsSuccess());
+ CHECK(Client.Get("/api/test/status/204").IsSuccess());
+ }
+
+ SUBCASE("4xx are not success")
+ {
+ CHECK(!Client.Get("/api/test/status/400").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/401").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/403").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/404").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/409").IsSuccess());
+ }
+
+ SUBCASE("5xx are not success")
+ {
+ CHECK(!Client.Get("/api/test/status/500").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/502").IsSuccess());
+ CHECK(!Client.Get("/api/test/status/503").IsSuccess());
+ }
+
+ SUBCASE("status code values match")
+ {
+ CHECK_EQ(Client.Get("/api/test/status/200").StatusCode, HttpResponseCode::OK);
+ CHECK_EQ(Client.Get("/api/test/status/201").StatusCode, HttpResponseCode::Created);
+ CHECK_EQ(Client.Get("/api/test/status/204").StatusCode, HttpResponseCode::NoContent);
+ CHECK_EQ(Client.Get("/api/test/status/400").StatusCode, HttpResponseCode::BadRequest);
+ CHECK_EQ(Client.Get("/api/test/status/401").StatusCode, HttpResponseCode::Unauthorized);
+ CHECK_EQ(Client.Get("/api/test/status/403").StatusCode, HttpResponseCode::Forbidden);
+ CHECK_EQ(Client.Get("/api/test/status/404").StatusCode, HttpResponseCode::NotFound);
+ CHECK_EQ(Client.Get("/api/test/status/409").StatusCode, HttpResponseCode::Conflict);
+ CHECK_EQ(Client.Get("/api/test/status/500").StatusCode, HttpResponseCode::InternalServerError);
+ CHECK_EQ(Client.Get("/api/test/status/502").StatusCode, HttpResponseCode::BadGateway);
+ CHECK_EQ(Client.Get("/api/test/status/503").StatusCode, HttpResponseCode::ServiceUnavailable);
+ }
+}
+
+TEST_CASE("httpclient.response")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("IsSuccess and operator bool for success")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK(static_cast<bool>(Resp));
+ }
+
+ SUBCASE("IsSuccess and operator bool for failure")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/404");
+ CHECK(!Resp.IsSuccess());
+ CHECK(!static_cast<bool>(Resp));
+ }
+
+ SUBCASE("AsText returns body")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("AsText returns empty for no-content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/nocontent");
+ CHECK(Resp.AsText().empty());
+ }
+
+ SUBCASE("AsObject parses CbObject")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ CbObject Obj = Resp.AsObject();
+ CHECK(Obj["ok"].AsBool() == true);
+ CHECK_EQ(Obj["message"].AsString(), "test");
+ }
+
+ SUBCASE("AsObject returns empty for non-CB content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CbObject Obj = Resp.AsObject();
+ CHECK(!Obj);
+ }
+
+ SUBCASE("ToText for text content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/text");
+ CHECK_EQ(Resp.ToText(), "plain text");
+ }
+
+ SUBCASE("ToText for CbObject content")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ std::string Text = Resp.ToText();
+ CHECK(!Text.empty());
+ // ToText for CbObject converts to JSON string representation
+ CHECK(Text.find("ok") != std::string::npos);
+ CHECK(Text.find("test") != std::string::npos);
+ }
+
+ SUBCASE("ErrorMessage includes status code on failure")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/404");
+ std::string Msg = Resp.ErrorMessage("test-prefix");
+ CHECK(Msg.find("test-prefix") != std::string::npos);
+ CHECK(Msg.find("404") != std::string::npos);
+ }
+
+ SUBCASE("ThrowError throws on failure")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/500");
+ CHECK_THROWS_AS(Resp.ThrowError("test"), HttpClientError);
+ }
+
+ SUBCASE("ThrowError does not throw on success")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK_NOTHROW(Resp.ThrowError("test"));
+ }
+
+ SUBCASE("HttpClientError carries response code")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/status/403");
+ try
+ {
+ Resp.ThrowError("test");
+ CHECK(false); // should not reach
+ }
+ catch (const HttpClientError& Err)
+ {
+ CHECK_EQ(Err.GetHttpResponseCode(), HttpResponseCode::Forbidden);
+ }
+ }
+}
+
+TEST_CASE("httpclient.error-handling")
+{
+ SUBCASE("Connection refused")
+ {
+ HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {});
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("Request timeout")
+ {
+ TestServerFixture Fixture;
+ HttpClientSettings Settings;
+ Settings.Timeout = std::chrono::milliseconds(500);
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/api/test/slow");
+ CHECK(!Resp.IsSuccess());
+ }
+
+ SUBCASE("Nonexistent endpoint returns failure")
+ {
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response Resp = Client.Get("/api/test/does-not-exist");
+ CHECK(!Resp.IsSuccess());
+ }
+}
+
+TEST_CASE("httpclient.session")
+{
+ TestServerFixture Fixture;
+
+ SUBCASE("Default session ID is non-empty")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ CHECK(!Client.GetSessionId().empty());
+ }
+
+ SUBCASE("SetSessionId changes ID")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ Oid NewId = Oid::NewOid();
+ std::string OldId = std::string(Client.GetSessionId());
+ Client.SetSessionId(NewId);
+ CHECK_EQ(Client.GetSessionId(), NewId.ToString());
+ CHECK_NE(Client.GetSessionId(), OldId);
+ }
+
+ SUBCASE("SetSessionId with Zero resets")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ Oid NewId = Oid::NewOid();
+ Client.SetSessionId(NewId);
+ CHECK_EQ(Client.GetSessionId(), NewId.ToString());
+ Client.SetSessionId(Oid::Zero);
+ // After resetting, should get a session string (not empty, not the custom one)
+ CHECK(!Client.GetSessionId().empty());
+ CHECK_NE(Client.GetSessionId(), NewId.ToString());
+ }
+}
+
+TEST_CASE("httpclient.authentication")
+{
+ TestServerFixture Fixture;
+
+ SUBCASE("Authenticate returns false without provider")
+ {
+ HttpClient Client = Fixture.MakeClient();
+ CHECK(!Client.Authenticate());
+ }
+
+ SUBCASE("Authenticate returns true with valid token")
+ {
+ HttpClientSettings Settings;
+ Settings.AccessTokenProvider = []() -> HttpClientAccessToken {
+ return HttpClientAccessToken{
+ .Value = "valid-token",
+ .ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours(1),
+ };
+ };
+ HttpClient Client = Fixture.MakeClient(Settings);
+ CHECK(Client.Authenticate());
+ }
+
+ SUBCASE("Authenticate returns false with expired token")
+ {
+ HttpClientSettings Settings;
+ Settings.AccessTokenProvider = []() -> HttpClientAccessToken {
+ return HttpClientAccessToken{
+ .Value = "expired-token",
+ .ExpireTime = HttpClientAccessToken::Clock::now() - std::chrono::hours(1),
+ };
+ };
+ HttpClient Client = Fixture.MakeClient(Settings);
+ CHECK(!Client.Authenticate());
+ }
+
+ SUBCASE("Bearer token verified by auth endpoint")
+ {
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response AuthResp =
+ Client.Get("/api/test/auth/bearer", std::pair<std::string, std::string>("Authorization", "Bearer my-secret-token"));
+ CHECK(AuthResp.IsSuccess());
+ CHECK_EQ(AuthResp.AsText(), "authenticated");
+ }
+
+ SUBCASE("Request without token to auth endpoint gets 401")
+ {
+ HttpClient Client = Fixture.MakeClient();
+
+ HttpClient::Response Resp = Client.Get("/api/test/auth/bearer");
+ CHECK(!Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::Unauthorized);
+ }
+}
+
+TEST_CASE("httpclient.content-types")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("text content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/text");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kText);
+ }
+
+ SUBCASE("JSON content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/json");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kJSON);
+ }
+
+ SUBCASE("binary content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/binary");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kBinary);
+ }
+
+ SUBCASE("CbObject content type")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/content-type/cbobject");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kCbObject);
+ }
+}
+
+TEST_CASE("httpclient.metadata")
+{
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("ElapsedSeconds is positive")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.ElapsedSeconds > 0.0);
+ }
+
+ SUBCASE("DownloadedBytes populated for GET")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.DownloadedBytes > 0);
+ }
+
+ SUBCASE("UploadedBytes populated for POST with payload")
+ {
+ const char* Payload = "some upload data";
+ IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload));
+ Buf.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Buf);
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.UploadedBytes > 0);
+ }
+}
+
+TEST_CASE("httpclient.retry")
+{
+ TestServerFixture Fixture;
+
+ SUBCASE("Retry succeeds after transient failures")
+ {
+ Fixture.TestService.ResetAttemptCounter(2);
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 3;
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/api/test/attempt-counter");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "success after retries");
+ }
+
+ SUBCASE("No retry returns 503 immediately")
+ {
+ Fixture.TestService.ResetAttemptCounter(2);
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 0;
+ HttpClient Client = Fixture.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/api/test/attempt-counter");
+ CHECK(!Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::ServiceUnavailable);
+ }
+}
+
+TEST_CASE("httpclient.measurelatency")
+{
+ SUBCASE("Successful measurement against live server")
+ {
+ TestServerFixture Fixture;
+ HttpClient Client = Fixture.MakeClient();
+
+ LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello");
+ CHECK(Result.Success);
+ CHECK(Result.LatencySeconds > 0.0);
+ }
+
+ SUBCASE("Failed measurement against unreachable port")
+ {
+ HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {});
+ LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello");
+ CHECK(!Result.Success);
+ CHECK(!Result.FailureReason.empty());
+ }
+}
+
+TEST_CASE("httpclient.keyvaluemap")
+{
+ SUBCASE("Default construction is empty")
+ {
+ HttpClient::KeyValueMap Map;
+ CHECK(Map->empty());
+ }
+
+ SUBCASE("Construction from pair")
+ {
+ HttpClient::KeyValueMap Map(std::pair<std::string, std::string>("key", "value"));
+ CHECK_EQ(Map->size(), 1u);
+ CHECK_EQ(Map->at("key"), "value");
+ }
+
+ SUBCASE("Construction from string_view pair")
+ {
+ HttpClient::KeyValueMap Map(std::pair<std::string_view, std::string_view>("key"sv, "value"sv));
+ CHECK_EQ(Map->size(), 1u);
+ CHECK_EQ(Map->at("key"), "value");
+ }
+
+ SUBCASE("Construction from initializer list")
+ {
+ HttpClient::KeyValueMap Map({{"a"sv, "1"sv}, {"b"sv, "2"sv}});
+ CHECK_EQ(Map->size(), 2u);
+ CHECK_EQ(Map->at("a"), "1");
+ CHECK_EQ(Map->at("b"), "2");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Transport fault testing
+
+static std::string
+MakeRawHttpResponse(int StatusCode, std::string_view Body)
+{
+ return fmt::format(
+ "HTTP/1.1 {} OK\r\n"
+ "Content-Type: text/plain\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n"
+ "{}",
+ StatusCode,
+ Body.size(),
+ Body);
+}
+
+static std::string
+MakeRawHttpHeaders(int StatusCode, size_t ContentLength)
+{
+ return fmt::format(
+ "HTTP/1.1 {} OK\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n",
+ StatusCode,
+ ContentLength);
+}
+
+static void
+DrainHttpRequest(asio::ip::tcp::socket& Socket)
+{
+ asio::streambuf Buf;
+ std::error_code Ec;
+ asio::read_until(Socket, Buf, "\r\n\r\n", Ec);
+}
+
+static void
+DrainFullHttpRequest(asio::ip::tcp::socket& Socket)
+{
+ // Read until end of headers
+ asio::streambuf Buf;
+ std::error_code Ec;
+ asio::read_until(Socket, Buf, "\r\n\r\n", Ec);
+ if (Ec)
+ {
+ return;
+ }
+
+ // Extract headers to find Content-Length
+ std::string Headers(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data()));
+
+ size_t ContentLength = 0;
+ auto Pos = Headers.find("Content-Length: ");
+ if (Pos == std::string::npos)
+ {
+ Pos = Headers.find("content-length: ");
+ }
+ if (Pos != std::string::npos)
+ {
+ size_t ValStart = Pos + 16; // length of "Content-Length: "
+ size_t ValEnd = Headers.find("\r\n", ValStart);
+ if (ValEnd != std::string::npos)
+ {
+ ContentLength = std::stoull(Headers.substr(ValStart, ValEnd - ValStart));
+ }
+ }
+
+ // Calculate how many body bytes were already read past the header boundary.
+ // asio::read_until may read past the delimiter, so Buf.data() contains everything read.
+ size_t HeaderEnd = Headers.find("\r\n\r\n") + 4;
+ size_t BodyBytesInBuf = Headers.size() > HeaderEnd ? Headers.size() - HeaderEnd : 0;
+ size_t Remaining = ContentLength > BodyBytesInBuf ? ContentLength - BodyBytesInBuf : 0;
+
+ if (Remaining > 0)
+ {
+ std::vector<char> BodyBuf(Remaining);
+ asio::read(Socket, asio::buffer(BodyBuf), Ec);
+ }
+}
+
+static void
+DrainPartialBody(asio::ip::tcp::socket& Socket, size_t BytesToRead)
+{
+ // Read headers first
+ asio::streambuf Buf;
+ std::error_code Ec;
+ asio::read_until(Socket, Buf, "\r\n\r\n", Ec);
+ if (Ec)
+ {
+ return;
+ }
+
+ // Determine how many body bytes were already buffered past headers
+ std::string All(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data()));
+ size_t HeaderEnd = All.find("\r\n\r\n") + 4;
+ size_t BodyBytesInBuf = All.size() > HeaderEnd ? All.size() - HeaderEnd : 0;
+
+ if (BodyBytesInBuf < BytesToRead)
+ {
+ size_t Remaining = BytesToRead - BodyBytesInBuf;
+ std::vector<char> BodyBuf(Remaining);
+ asio::read(Socket, asio::buffer(BodyBuf), Ec);
+ }
+}
+
+struct FaultTcpServer
+{
+ using FaultHandler = std::function<void(asio::ip::tcp::socket&)>;
+
+ asio::io_context m_IoContext;
+ asio::ip::tcp::acceptor m_Acceptor;
+ FaultHandler m_Handler;
+ std::thread m_Thread;
+ int m_Port;
+
+ explicit FaultTcpServer(FaultHandler Handler)
+ : m_Acceptor(m_IoContext, asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), 0))
+ , m_Handler(std::move(Handler))
+ {
+ m_Port = m_Acceptor.local_endpoint().port();
+ StartAccept();
+ m_Thread = std::thread([this]() {
+ try
+ {
+ m_IoContext.run();
+ }
+ catch (...)
+ {
+ }
+ });
+ }
+
+ ~FaultTcpServer()
+ {
+ // io_context::stop() is thread-safe; do NOT call m_Acceptor.close() from this
+ // thread — ASIO I/O objects are not safe for concurrent access and the io_context
+ // thread may be touching the acceptor in StartAccept().
+ m_IoContext.stop();
+ if (m_Thread.joinable())
+ {
+ m_Thread.join();
+ }
+ }
+
+ FaultTcpServer(const FaultTcpServer&) = delete;
+ FaultTcpServer& operator=(const FaultTcpServer&) = delete;
+
+ void StartAccept()
+ {
+ m_Acceptor.async_accept([this](std::error_code Ec, asio::ip::tcp::socket Socket) {
+ if (!Ec)
+ {
+ m_Handler(Socket);
+ }
+ if (m_Acceptor.is_open())
+ {
+ StartAccept();
+ }
+ });
+ }
+
+ HttpClient MakeClient(HttpClientSettings Settings = {})
+ {
+ return HttpClient(fmt::format("127.0.0.1:{}", m_Port), Settings, /*CheckIfAbortFunction*/ {});
+ }
+};
+
+TEST_CASE("httpclient.range-response")
+{
+ ScopedTemporaryDirectory DownloadDir;
+
+ SUBCASE("single range 206 response populates Ranges")
+ {
+ std::string RangeBody(100, 'A');
+
+ FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = fmt::format(
+ "HTTP/1.1 206 Partial Content\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Range: bytes 200-299/1000\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n"
+ "{}",
+ RangeBody.size(),
+ RangeBody);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ });
+
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::PartialContent);
+ REQUIRE(Resp.Ranges.size() == 1);
+ CHECK_EQ(Resp.Ranges[0].RangeOffset, 200);
+ CHECK_EQ(Resp.Ranges[0].RangeLength, 100);
+ }
+
+ SUBCASE("multipart byteranges 206 response populates Ranges")
+ {
+ std::string Part1Data(16, 'X');
+ std::string Part2Data(12, 'Y');
+ std::string Boundary = "testboundary123";
+
+ std::string MultipartBody = fmt::format(
+ "\r\n--{}\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Range: bytes 100-115/1000\r\n"
+ "\r\n"
+ "{}"
+ "\r\n--{}\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Range: bytes 500-511/1000\r\n"
+ "\r\n"
+ "{}"
+ "\r\n--{}--",
+ Boundary,
+ Part1Data,
+ Boundary,
+ Part2Data,
+ Boundary);
+
+ FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = fmt::format(
+ "HTTP/1.1 206 Partial Content\r\n"
+ "Content-Type: multipart/byteranges; boundary={}\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n"
+ "{}",
+ Boundary,
+ MultipartBody.size(),
+ MultipartBody);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ });
+
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::PartialContent);
+ REQUIRE(Resp.Ranges.size() == 2);
+ // Ranges should be sorted by RangeOffset
+ CHECK_EQ(Resp.Ranges[0].RangeOffset, 100);
+ CHECK_EQ(Resp.Ranges[0].RangeLength, 16);
+ CHECK_EQ(Resp.Ranges[1].RangeOffset, 500);
+ CHECK_EQ(Resp.Ranges[1].RangeLength, 12);
+ }
+
+ SUBCASE("non-range 200 response has empty Ranges")
+ {
+ FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = MakeRawHttpResponse(200, "full content");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ });
+
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.Ranges.empty());
+ }
+}
+
+TEST_CASE("httpclient.transport-faults" * doctest::skip())
+{
+ SUBCASE("connection reset before response")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("connection closed before response")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("partial headers then close")
+ {
+ // libcurl parses the status line (200 OK) and accepts the response even though
+ // headers are truncated mid-field. It reports success with an empty body instead
+ // of an error. Ideally this should be detected as a transport failure.
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Partial = "HTTP/1.1 200 OK\r\nContent-";
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Partial), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ WARN(!Resp.IsSuccess());
+ WARN(Resp.Error.has_value());
+ }
+
+ SUBCASE("truncated body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Headers = MakeRawHttpHeaders(200, 1000);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Headers), Ec);
+ std::string PartialBody(100, 'x');
+ asio::write(Socket, asio::buffer(PartialBody), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("connection reset mid-body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Headers = MakeRawHttpHeaders(200, 10000);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Headers), Ec);
+ std::string PartialBody(1000, 'x');
+ asio::write(Socket, asio::buffer(PartialBody), Ec);
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("stalled response triggers timeout")
+ {
+ std::atomic<bool> StallActive{true};
+ FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Headers = MakeRawHttpHeaders(200, 1000);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Headers), Ec);
+ while (StallActive.load())
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(50));
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.Timeout = std::chrono::milliseconds(500);
+ HttpClient Client = Server.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ StallActive.store(false);
+ }
+
+ SUBCASE("retry succeeds after transient failures")
+ {
+ std::atomic<int> ConnCount{0};
+ FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) {
+ int N = ConnCount.fetch_add(1);
+ DrainHttpRequest(Socket);
+ if (N < 2)
+ {
+ // Connection reset produces NETWORK_SEND_FAILURE which is retryable
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ }
+ else
+ {
+ std::string Response = MakeRawHttpResponse(200, "recovered");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 3;
+ HttpClient Client = Server.MakeClient(Settings);
+
+ HttpClient::Response Resp = Client.Get("/test");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "recovered");
+ }
+}
+
+TEST_CASE("httpclient.transport-faults-post" * doctest::skip())
+{
+ constexpr size_t kPostBodySize = 256 * 1024;
+
+ auto MakePostBody = []() -> IoBuffer {
+ IoBuffer Buf(kPostBodySize);
+ uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData());
+ for (size_t i = 0; i < kPostBodySize; ++i)
+ {
+ Ptr[i] = static_cast<uint8_t>(i & 0xFF);
+ }
+ Buf.SetContentType(ZenContentType::kBinary);
+ return Buf;
+ };
+
+ SUBCASE("POST: server resets before consuming body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("POST: server closes before consuming body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("POST: server resets mid-body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainPartialBody(Socket, 8 * 1024);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ }
+
+ SUBCASE("POST: early error response before consuming body")
+ {
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = MakeRawHttpResponse(503, "service busy");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec);
+ Socket.close(Ec);
+ });
+ HttpClient Client = Server.MakeClient();
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ // With a large upload body, the server may RST the connection before the client
+ // reads the 503 response. Either outcome is valid: the client sees the HTTP 503
+ // status, or it sees a transport-level error from the RST.
+ CHECK((Resp.StatusCode == HttpResponseCode::ServiceUnavailable || Resp.Error.has_value()));
+ }
+
+ SUBCASE("POST: stalled upload triggers timeout")
+ {
+ std::atomic<bool> StallActive{true};
+ FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ // Stop reading body — TCP window will fill and client send will stall
+ while (StallActive.load())
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(50));
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.Timeout = std::chrono::milliseconds(2000);
+ HttpClient Client = Server.MakeClient(Settings);
+
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(!Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ StallActive.store(false);
+ }
+
+ SUBCASE("POST: retry with large body after transient failure")
+ {
+ std::atomic<int> ConnCount{0};
+ FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) {
+ int N = ConnCount.fetch_add(1);
+ if (N < 2)
+ {
+ DrainHttpRequest(Socket);
+ std::error_code Ec;
+ Socket.set_option(asio::socket_base::linger(true, 0), Ec);
+ Socket.close(Ec);
+ }
+ else
+ {
+ DrainFullHttpRequest(Socket);
+ std::string Response = MakeRawHttpResponse(200, "upload-ok");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ }
+ });
+
+ HttpClientSettings Settings;
+ Settings.RetryCount = 3;
+ HttpClient Client = Server.MakeClient(Settings);
+
+ IoBuffer Body = MakePostBody();
+ HttpClient::Response Resp = Client.Post("/test", Body);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "upload-ok");
+ }
+}
+
+TEST_CASE("httpclient.unixsocket")
+{
+ ScopedTemporaryDirectory TmpDir;
+ std::string SocketPath = (TmpDir.Path() / "zen.sock").string();
+
+ HttpClientTestService TestService;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath});
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto _ = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ HttpClientSettings Settings;
+ Settings.UnixSocketPath = SocketPath;
+
+ HttpClient Client("localhost", Settings, /*CheckIfAbortFunction*/ {});
+
+ SUBCASE("GET over unix socket")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("POST echo over unix socket")
+ {
+ const char* Payload = "unix socket payload";
+ IoBuffer Body(IoBuffer::Clone, Payload, strlen(Payload));
+ Body.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Body);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "unix socket payload");
+ }
+}
+
+# if ZEN_USE_OPENSSL
+
+TEST_CASE("httpclient.https")
+{
+ // Self-signed test certificate for localhost/127.0.0.1, valid until 2036
+ static constexpr std::string_view TestCertPem =
+ "-----BEGIN CERTIFICATE-----\n"
+ "MIIDJTCCAg2gAwIBAgIUEtJYMSUmJmvJ157We/qXNVJ7W8gwDQYJKoZIhvcNAQEL\n"
+ "BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMwOTIwMjU1M1oXDTM2MDMw\n"
+ "NjIwMjU1M1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\n"
+ "AAOCAQ8AMIIBCgKCAQEAv9YvZ6WeBz3z/Zuxi6OIivWksDxDZZ5oAXKVwlUXaa7v\n"
+ "iDkm9P5ZsEhN+M5vZMe2Yb9i3cnTUaE6Avs1ddOwTAYNGrE/B5DmibrRWc23R0cv\n"
+ "gdnYQJ+gjsAeMvUWYLK58xW4YoMR5bmfpj1ruqobUNkG/oJYnAUcjgo4J149irW+\n"
+ "4n9uLJvxL+5fI/b/AIkv+4TMe70/d/BPmnixWrrzxUT6S5ghE2Mq7+XLScfpY2Sp\n"
+ "GQ/Xbnj9/ELYLpQnNLuVZwWZDpXj+FLbF1zxgjYdw1cCjbRcOIEW2/GJeJvGXQ6Y\n"
+ "Vld5pCBm9uKPPLWoFCoakK5YvP00h+8X+HghGVSscQIDAQABo28wbTAdBgNVHQ4E\n"
+ "FgQUgM6hjymi6g2EBUg2ENu0nIK8yhMwHwYDVR0jBBgwFoAUgM6hjymi6g2EBUg2\n"
+ "ENu0nIK8yhMwDwYDVR0TAQH/BAUwAwEB/zAaBgNVHREEEzARhwR/AAABgglsb2Nh\n"
+ "bGhvc3QwDQYJKoZIhvcNAQELBQADggEBABY1oaaWwL4RaK/epKvk/IrmVT2mlAai\n"
+ "uvGLfjhc6FGvXaxPGTSUPrVbFornaWZAg7bOWCexWnEm2sWd75V/usvZAPN4aIiD\n"
+ "H66YQipq3OD4F9Gowp01IU4AcGh7MerFpYPk76+wp2ANq71x8axtlZjVn3hSFMmN\n"
+ "i6m9S/eyCl9WjYBT5ZEC4fJV0nOSmNe/+gCAm11/js9zNfXKmUchJtuZpubY3A0k\n"
+ "X2II6qYWf1PH+JJkefNZtt2c66CrEN5eAg4/rGEgsp43zcd4ZHVkpBKFLDEls1ev\n"
+ "drQ45zc4Ht77pHfnHu7YsLcRZ9Wq3COMNZYx5lItqnomX2qBm1pkwjI=\n"
+ "-----END CERTIFICATE-----\n";
+
+ static constexpr std::string_view TestKeyPem =
+ "-----BEGIN PRIVATE KEY-----\n"
+ "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC/1i9npZ4HPfP9\n"
+ "m7GLo4iK9aSwPENlnmgBcpXCVRdpru+IOSb0/lmwSE34zm9kx7Zhv2LdydNRoToC\n"
+ "+zV107BMBg0asT8HkOaJutFZzbdHRy+B2dhAn6COwB4y9RZgsrnzFbhigxHluZ+m\n"
+ "PWu6qhtQ2Qb+glicBRyOCjgnXj2Ktb7if24sm/Ev7l8j9v8AiS/7hMx7vT938E+a\n"
+ "eLFauvPFRPpLmCETYyrv5ctJx+ljZKkZD9dueP38QtgulCc0u5VnBZkOleP4UtsX\n"
+ "XPGCNh3DVwKNtFw4gRbb8Yl4m8ZdDphWV3mkIGb24o88tagUKhqQrli8/TSH7xf4\n"
+ "eCEZVKxxAgMBAAECggEAILd9pDaZqfCF8SWhdQgx3Ekiii/s6qLGaCDLq7XpZUvB\n"
+ "bEEbBMNwNmFOcvV6B/0LfMYwLVUjZhOSGjoPlwXAVmbdy0SZVEgBGVI0LBWqgUyB\n"
+ "rKqjd/oBXvci71vfMiSpE+0LYjmqTryGnspw2gfy2qn4yGUgiZNRmGPjycsHweUL\n"
+ "V3FHm3cf0dyE4sJ0mjVqZzRT/unw2QOCE6FlY7M1XxZL88IWfn6G4lckdJTwoOP5\n"
+ "VPR2J3XbyhvCeXeDRCHKRXojWWR2HovWnDXQc95GRgCd0vYdHuIUM6RXVPZQvy3X\n"
+ "l0GwQKHNcVr1uwtYDgGKw0tNCUDvxdfQaWilTFuicQKBgQDvEYp+vL1hnF+AVdu3\n"
+ "elsYsHpFgExkTI8wnUMvGZrFiIQyCyVDU3jkG3kcKacI1bfwopXopaQCjrYk9epm\n"
+ "liOVm3/Xtr6e2ENa7w8TQbdK65PciQNOMxml6g8clRRBl0cwj+aI3nW/Kop1cdrR\n"
+ "A9Vo+8iPTO5gDcxTiIb45a6E3QKBgQDNbE009P6ewx9PU7Llkhb9VBgsb7oQN3EV\n"
+ "TCYd4taiN6FPnTuL/cdijAA8y04hiVT+Efo9TUN9NCl9HdHXQcjj7/n/eFLH0Pkw\n"
+ "OIK3QN49OfR88wivLMtwWxIog0tJjc9+7dR4bR4o1jTlIrasEIvUTuDJQ8MKGc9v\n"
+ "pBITua+SpQKBgE4raSKZqj7hd6Sp7kbnHiRLiB9znQbqtaNKuK4M7DuMsNUAKfYC\n"
+ "tDO5+/bGc9SCtTtcnjHM/3zKlyossrFKhGYlyz6IhXnA8v0nz8EXKsy3jMh+kHMg\n"
+ "aFGE394TrOTphyCM3O+B9fRE/7L5QHg5ja1fLqwUlpkXyejCaoe16kONAoGAYIz9\n"
+ "wN1B67cEOVG6rOI8QfdLoV8mEcctNHhlFfjvLrF89SGOwl6WX0A0QF7CK0sUEpK6\n"
+ "jiOJjAh/U5o3bbgyxsedNjEEn3weE0cMUTuA+UALJMtKEqO4PuffIgGL2ld35k28\n"
+ "ZpnK6iC8HdJyD297eV9VkeNygYXeFLgF8xV8ay0CgYEAh4fmVZt9YhgVByYny2kF\n"
+ "ZUIkGF5h9wxzVOPpQwpizIGFFb3i/ZdGQcuLTfIBVRKf50sT3IwJe65ATv6+Lz0f\n"
+ "wg/pMvosi0/F5KGbVRVdzBMQy58WyyGti4tNl+8EXGvo8+DCmjlTYwfjRoZGg/qJ\n"
+ "EMP3/hTN7dHDRxPK8E0Fh0Y=\n"
+ "-----END PRIVATE KEY-----\n";
+
+ ScopedTemporaryDirectory TmpDir;
+
+ // Write cert and key to temp files
+ const auto CertPath = TmpDir.Path() / "test.crt";
+ const auto KeyPath = TmpDir.Path() / "test.key";
+ WriteFile(CertPath, IoBuffer(IoBuffer::Clone, TestCertPem.data(), TestCertPem.size()));
+ WriteFile(KeyPath, IoBuffer(IoBuffer::Clone, TestKeyPem.data(), TestKeyPem.size()));
+
+ HttpClientTestService TestService;
+
+ AsioConfig Config;
+ Config.CertFile = CertPath.string();
+ Config.KeyFile = KeyPath.string();
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(Config);
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto _ = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ int HttpsPort = Server->GetEffectiveHttpsPort();
+ REQUIRE(HttpsPort > 0);
+
+ HttpClientSettings Settings;
+ Settings.InsecureSsl = true;
+
+ HttpClient Client(fmt::format("https://127.0.0.1:{}", HttpsPort), Settings, /*CheckIfAbortFunction*/ {});
+
+ SUBCASE("GET over HTTPS")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("POST echo over HTTPS")
+ {
+ const char* Payload = "https payload";
+ IoBuffer Body(IoBuffer::Clone, Payload, strlen(Payload));
+ Body.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Body);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "https payload");
+ }
+
+ SUBCASE("GET JSON over HTTPS")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ CHECK(Resp.IsSuccess());
+ CbObject Obj = Resp.AsObject();
+ CHECK_EQ(Obj["ok"].AsBool(), true);
+ CHECK_EQ(Obj["message"].AsString(), "test");
+ }
+
+ SUBCASE("Large payload over HTTPS")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/large");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+ }
+}
+
+# endif // ZEN_USE_OPENSSL
+
+TEST_SUITE_END();
+
+void
+httpclient_test_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index 761665c30..69000dd8e 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -2,6 +2,8 @@
#include <zenhttp/httpserver.h>
+#include <zencore/filesystem.h>
+
#include "servers/httpasio.h"
#include "servers/httpmulti.h"
#include "servers/httpnull.h"
@@ -23,10 +25,12 @@
#include <zencore/logging.h>
#include <zencore/stream.h>
#include <zencore/string.h>
+#include <zencore/system.h>
#include <zencore/testing.h>
#include <zencore/thread.h>
#include <zenhttp/packageformat.h>
#include <zentelemetry/otlptrace.h>
+#include <zentelemetry/stats.h>
#include <charconv>
#include <mutex>
@@ -745,6 +749,10 @@ HttpRequestRouter::RegisterRoute(const char* UriPattern, HttpRequestRouter::Hand
{
if (UriPattern[i] == '}')
{
+ if (i == PatternStart)
+ {
+ throw std::runtime_error(fmt::format("matcher pattern is empty in URI pattern '{}'", UriPattern));
+ }
std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart);
if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end())
{
@@ -910,8 +918,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
CapturedSegments.emplace_back(Uri);
- for (int MatcherIndex : Matchers)
+ for (size_t MatcherOffset = 0; MatcherOffset < Matchers.size(); MatcherOffset++)
{
+ int MatcherIndex = Matchers[MatcherOffset];
if (UriPos >= UriLen)
{
IsMatch = false;
@@ -921,9 +930,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
if (MatcherIndex < 0)
{
// Literal match
- int LitIndex = -MatcherIndex - 1;
- const std::string& LitStr = m_Literals[LitIndex];
- size_t LitLen = LitStr.length();
+ int LitIndex = -MatcherIndex - 1;
+ std::string_view LitStr = m_Literals[LitIndex];
+ size_t LitLen = LitStr.length();
if (Uri.substr(UriPos, LitLen) == LitStr)
{
@@ -939,9 +948,18 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
{
// Matcher function
size_t SegmentStart = UriPos;
- while (UriPos < UriLen && Uri[UriPos] != '/')
+
+ if (MatcherOffset == (Matchers.size() - 1))
{
- ++UriPos;
+ // Last matcher, use the remaining part of the uri
+ UriPos = UriLen;
+ }
+ else
+ {
+ while (UriPos < UriLen && Uri[UriPos] != '/')
+ {
+ ++UriPos;
+ }
}
std::string_view Segment = Uri.substr(SegmentStart, UriPos - SegmentStart);
@@ -1014,7 +1032,31 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
int
HttpServer::Initialize(int BasePort, std::filesystem::path DataDir)
{
- return OnInitialize(BasePort, std::move(DataDir));
+ m_EffectivePort = OnInitialize(BasePort, std::move(DataDir));
+ m_ExternalHost = OnGetExternalHost();
+ return m_EffectivePort;
+}
+
+std::string
+HttpServer::OnGetExternalHost() const
+{
+ return GetMachineName();
+}
+
+std::string
+HttpServer::GetServiceUri(const HttpService* Service) const
+{
+ const char* Scheme = (m_EffectiveHttpsPort > 0) ? "https" : "http";
+ int Port = (m_EffectiveHttpsPort > 0) ? m_EffectiveHttpsPort : m_EffectivePort;
+
+ if (Service)
+ {
+ return fmt::format("{}://{}:{}{}", Scheme, m_ExternalHost, Port, Service->BaseUri());
+ }
+ else
+ {
+ return fmt::format("{}://{}:{}", Scheme, m_ExternalHost, Port);
+ }
}
void
@@ -1058,6 +1100,39 @@ HttpServer::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
OnSetHttpRequestFilter(RequestFilter);
}
+CbObject
+HttpServer::CollectStats()
+{
+ CbObjectWriter Cbo;
+
+ metrics::EmitSnapshot("requests", m_RequestMeter, Cbo);
+
+ Cbo.BeginObject("bytes");
+ {
+ Cbo << "received" << GetTotalBytesReceived();
+ Cbo << "sent" << GetTotalBytesSent();
+ }
+ Cbo.EndObject();
+
+ Cbo.BeginObject("websockets");
+ {
+ Cbo << "active_connections" << GetActiveWebSocketConnectionCount();
+ Cbo << "frames_received" << m_WsFramesReceived.load(std::memory_order_relaxed);
+ Cbo << "frames_sent" << m_WsFramesSent.load(std::memory_order_relaxed);
+ Cbo << "bytes_received" << m_WsBytesReceived.load(std::memory_order_relaxed);
+ Cbo << "bytes_sent" << m_WsBytesSent.load(std::memory_order_relaxed);
+ }
+ Cbo.EndObject();
+
+ return Cbo.Save();
+}
+
+void
+HttpServer::HandleStatsRequest(HttpServerRequest& Request)
+{
+ Request.WriteResponse(HttpResponseCode::OK, CollectStats());
+}
+
//////////////////////////////////////////////////////////////////////////
HttpRpcHandler::HttpRpcHandler()
@@ -1082,9 +1157,13 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig
if (ServerClass == "asio"sv)
{
ZEN_INFO("using asio HTTP server implementation")
- return CreateHttpAsioServer(AsioConfig{.ThreadCount = Config.ThreadCount,
- .ForceLoopback = Config.ForceLoopback,
- .IsDedicatedServer = Config.IsDedicatedServer});
+ return CreateHttpAsioServer(AsioConfig {
+ .ThreadCount = Config.ThreadCount, .ForceLoopback = Config.ForceLoopback, .IsDedicatedServer = Config.IsDedicatedServer,
+ .NoNetwork = Config.NoNetwork, .UnixSocketPath = PathToUtf8(Config.UnixSocketPath),
+#if ZEN_USE_OPENSSL
+ .HttpsPort = Config.HttpsPort, .CertFile = Config.CertFile, .KeyFile = Config.KeyFile,
+#endif
+ });
}
#if ZEN_WITH_HTTPSYS
else if (ServerClass == "httpsys"sv)
@@ -1096,7 +1175,11 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig
.IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled,
.IsDedicatedServer = Config.IsDedicatedServer,
.ForceLoopback = Config.ForceLoopback,
- .UseExplicitIoThreadPool = Config.HttpSys.UseExplicitIoThreadPool}));
+ .UseExplicitIoThreadPool = Config.HttpSys.UseExplicitIoThreadPool,
+ .HttpsPort = Config.HttpSys.HttpsPort,
+ .CertThumbprint = Config.HttpSys.CertThumbprint,
+ .CertStoreName = Config.HttpSys.CertStoreName,
+ .HttpsOnly = Config.HttpSys.HttpsOnly}));
}
#endif
else if (ServerClass == "null"sv)
@@ -1301,6 +1384,8 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("http.httpserver");
+
TEST_CASE("http.common")
{
using namespace std::literals;
@@ -1406,20 +1491,33 @@ TEST_CASE("http.common")
SUBCASE("router-matcher")
{
- bool HandledA = false;
- bool HandledAA = false;
- bool HandledAB = false;
- bool HandledAandB = false;
+ bool HandledA = false;
+ bool HandledAA = false;
+ bool HandledAB = false;
+ bool HandledAandB = false;
+ bool HandledAandPath = false;
std::vector<std::string> Captures;
auto Reset = [&] {
- HandledA = HandledAA = HandledAB = HandledAandB = false;
+ HandledA = HandledAA = HandledAB = HandledAandB = HandledAandPath = false;
Captures.clear();
};
TestHttpService Service;
HttpRequestRouter r;
- r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0; });
- r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0; });
+
+ r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0 && In.find('/') == std::string_view::npos; });
+ r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0 && In.find('/') == std::string_view::npos; });
+ static constexpr AsciiSet ValidPathCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789/_.,;$~{}+-[]%()]ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+ r.AddMatcher("path", [](std::string_view Str) -> bool { return !Str.empty() && AsciiSet::HasOnly(Str, ValidPathCharactersSet); });
+
+ r.RegisterRoute(
+ "path/{a}/{path}",
+ [&](auto& Req) {
+ HandledAandPath = true;
+ Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))};
+ },
+ HttpVerb::kGet);
+
r.RegisterRoute(
"{a}",
[&](auto& Req) {
@@ -1448,7 +1546,6 @@ TEST_CASE("http.common")
Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))};
},
HttpVerb::kGet);
-
{
Reset();
TestHttpServerRequest req{Service, "ab"sv};
@@ -1456,6 +1553,7 @@ TEST_CASE("http.common")
CHECK(HandledA);
CHECK(!HandledAA);
CHECK(!HandledAB);
+ CHECK(!HandledAandPath);
REQUIRE_EQ(Captures.size(), 1);
CHECK_EQ(Captures[0], "ab"sv);
@@ -1468,6 +1566,7 @@ TEST_CASE("http.common")
CHECK(!HandledA);
CHECK(!HandledAA);
CHECK(HandledAB);
+ CHECK(!HandledAandPath);
REQUIRE_EQ(Captures.size(), 2);
CHECK_EQ(Captures[0], "ab"sv);
CHECK_EQ(Captures[1], "def"sv);
@@ -1481,6 +1580,7 @@ TEST_CASE("http.common")
CHECK(!HandledAA);
CHECK(!HandledAB);
CHECK(HandledAandB);
+ CHECK(!HandledAandPath);
REQUIRE_EQ(Captures.size(), 2);
CHECK_EQ(Captures[0], "ab"sv);
CHECK_EQ(Captures[1], "def"sv);
@@ -1493,6 +1593,7 @@ TEST_CASE("http.common")
CHECK(!HandledA);
CHECK(!HandledAA);
CHECK(!HandledAB);
+ CHECK(!HandledAandPath);
}
{
@@ -1502,6 +1603,35 @@ TEST_CASE("http.common")
CHECK(HandledA);
CHECK(!HandledAA);
CHECK(!HandledAB);
+ CHECK(!HandledAandPath);
+ REQUIRE_EQ(Captures.size(), 1);
+ CHECK_EQ(Captures[0], "a123"sv);
+ }
+
+ {
+ Reset();
+ TestHttpServerRequest req{Service, "path/ab/simple_path.txt"sv};
+ r.HandleRequest(req);
+ CHECK(!HandledA);
+ CHECK(!HandledAA);
+ CHECK(!HandledAB);
+ CHECK(HandledAandPath);
+ REQUIRE_EQ(Captures.size(), 2);
+ CHECK_EQ(Captures[0], "ab"sv);
+ CHECK_EQ(Captures[1], "simple_path.txt"sv);
+ }
+
+ {
+ Reset();
+ TestHttpServerRequest req{Service, "path/ab/directory/and/path.txt"sv};
+ r.HandleRequest(req);
+ CHECK(!HandledA);
+ CHECK(!HandledAA);
+ CHECK(!HandledAB);
+ CHECK(HandledAandPath);
+ REQUIRE_EQ(Captures.size(), 2);
+ CHECK_EQ(Captures[0], "ab"sv);
+ CHECK_EQ(Captures[1], "directory/and/path.txt"sv);
}
}
@@ -1519,6 +1649,8 @@ TEST_CASE("http.common")
}
}
+TEST_SUITE_END();
+
void
http_forcelink()
{
diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h
index c252a5d99..3cfe652c5 100644
--- a/src/zenhttp/include/zenhttp/cprutils.h
+++ b/src/zenhttp/include/zenhttp/cprutils.h
@@ -2,17 +2,19 @@
#pragma once
-#include <zencore/compactbinary.h>
-#include <zencore/compactbinaryvalidation.h>
-#include <zencore/iobuffer.h>
-#include <zencore/string.h>
-#include <zenhttp/formatters.h>
-#include <zenhttp/httpclient.h>
-#include <zenhttp/httpcommon.h>
+#if ZEN_WITH_CPR
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinaryvalidation.h>
+# include <zencore/iobuffer.h>
+# include <zencore/string.h>
+# include <zenhttp/formatters.h>
+# include <zenhttp/httpclient.h>
+# include <zenhttp/httpcommon.h>
ZEN_THIRD_PARTY_INCLUDES_START
-#include <cpr/response.h>
-#include <fmt/format.h>
+# include <cpr/response.h>
+# include <fmt/format.h>
ZEN_THIRD_PARTY_INCLUDES_END
template<>
@@ -92,3 +94,5 @@ struct fmt::formatter<cpr::Response>
}
}
};
+
+#endif // ZEN_WITH_CPR
diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h
index addb00cb8..90180391c 100644
--- a/src/zenhttp/include/zenhttp/formatters.h
+++ b/src/zenhttp/include/zenhttp/formatters.h
@@ -73,7 +73,7 @@ struct fmt::formatter<zen::HttpClient::Response>
if (Response.IsSuccess())
{
return fmt::format_to(Ctx.out(),
- "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s",
+ "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}",
ToString(Response.StatusCode),
Response.UploadedBytes,
Response.DownloadedBytes,
@@ -84,7 +84,7 @@ struct fmt::formatter<zen::HttpClient::Response>
return fmt::format_to(Ctx.out(),
"Failed: Elapsed: {}, Reason: ({}) '{}",
NiceResponseTime,
- Response.Error.value().ErrorCode,
+ static_cast<int>(Response.Error.value().ErrorCode),
Response.Error.value().ErrorMessage);
}
else
diff --git a/src/zenhttp/include/zenhttp/httpapiservice.h b/src/zenhttp/include/zenhttp/httpapiservice.h
index 0270973bf..2d384d1d8 100644
--- a/src/zenhttp/include/zenhttp/httpapiservice.h
+++ b/src/zenhttp/include/zenhttp/httpapiservice.h
@@ -1,4 +1,5 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#pragma once
#include <zenhttp/httpserver.h>
diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h
index 9a9b74d72..e878c900f 100644
--- a/src/zenhttp/include/zenhttp/httpclient.h
+++ b/src/zenhttp/include/zenhttp/httpclient.h
@@ -10,9 +10,11 @@
#include <zencore/uid.h>
#include <zenhttp/httpcommon.h>
+#include <filesystem>
#include <functional>
#include <optional>
#include <unordered_map>
+#include <vector>
namespace zen {
@@ -29,6 +31,36 @@ class CompositeBuffer;
*/
+enum class HttpClientErrorCode : int
+{
+ kOK = 0,
+ kConnectionFailure,
+ kHostResolutionFailure,
+ kProxyResolutionFailure,
+ kInternalError,
+ kNetworkReceiveError,
+ kNetworkSendFailure,
+ kOperationTimedOut,
+ kSSLConnectError,
+ kSSLCertificateError,
+ kSSLCACertError,
+ kGenericSSLError,
+ kRequestCancelled,
+ kOtherError,
+};
+
+enum class HttpClientBackend : uint8_t
+{
+ kDefault,
+#if ZEN_WITH_CPR
+ kCpr,
+#endif
+ kCurl,
+};
+
+void SetDefaultHttpClientBackend(std::string_view Backend);
+void SetDefaultHttpClientBackend(HttpClientBackend Backend);
+
struct HttpClientAccessToken
{
using Clock = std::chrono::system_clock;
@@ -58,6 +90,26 @@ struct HttpClientSettings
Oid SessionId = Oid::Zero;
bool Verbose = false;
uint64_t MaximumInMemoryDownloadSize = 1024u * 1024u;
+ HttpClientBackend Backend = HttpClientBackend::kDefault;
+
+ /// Unix domain socket path. When non-empty, the client connects via this
+ /// socket instead of TCP. BaseUri is still used for the Host header and URL.
+ std::filesystem::path UnixSocketPath;
+
+ /// Disable HTTP keep-alive by closing the connection after each request.
+ /// Useful for testing per-connection overhead.
+ bool ForbidReuseConnection = false;
+
+ /// Skip TLS certificate verification (for testing with self-signed certs).
+ bool InsecureSsl = false;
+
+ /// CA certificate bundle path for TLS verification. When non-empty, overrides
+ /// the system default CA store.
+ std::string CaBundlePath;
+
+ /// HTTP status codes that are expected and should not be logged as warnings.
+ /// 404 is always treated as expected regardless of this list.
+ std::vector<HttpResponseCode> ExpectedErrorCodes;
};
class HttpClientError : public std::runtime_error
@@ -65,22 +117,22 @@ class HttpClientError : public std::runtime_error
public:
using _Mybase = runtime_error;
- HttpClientError(const std::string& Message, int Error, HttpResponseCode ResponseCode)
+ HttpClientError(const std::string& Message, HttpClientErrorCode Error, HttpResponseCode ResponseCode)
: _Mybase(Message)
, m_Error(Error)
, m_ResponseCode(ResponseCode)
{
}
- HttpClientError(const char* Message, int Error, HttpResponseCode ResponseCode)
+ HttpClientError(const char* Message, HttpClientErrorCode Error, HttpResponseCode ResponseCode)
: _Mybase(Message)
, m_Error(Error)
, m_ResponseCode(ResponseCode)
{
}
- inline int GetInternalErrorCode() const { return m_Error; }
- inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; }
+ inline HttpClientErrorCode GetInternalErrorCode() const { return m_Error; }
+ inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; }
enum class ResponseClass : std::int8_t
{
@@ -107,24 +159,51 @@ public:
ResponseClass GetResponseClass() const;
private:
- const int m_Error = 0;
- const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot;
+ const HttpClientErrorCode m_Error = HttpClientErrorCode::kOK;
+ const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot;
};
class HttpClientBase;
+/** HTTP Client
+ *
+ * This is safe for use on multiple threads simultaneously, as each
+ * instance maintains an internal connection pool and will synchronize
+ * access to it as needed.
+ *
+ * Uses libcurl under the hood. We currently only use HTTP 1.1 features.
+ *
+ */
class HttpClient
{
public:
- HttpClient(std::string_view BaseUri,
- const HttpClientSettings& Connectionsettings = {},
- std::function<bool()>&& CheckIfAbortFunction = {});
+ explicit HttpClient(std::string_view BaseUri,
+ const HttpClientSettings& Connectionsettings = {},
+ std::function<bool()>&& CheckIfAbortFunction = {});
~HttpClient();
+ HttpClient(const HttpClient&) = delete;
+ HttpClient& operator=(const HttpClient&) = delete;
+
struct ErrorContext
{
- int ErrorCode;
- std::string ErrorMessage;
+ HttpClientErrorCode ErrorCode;
+ std::string ErrorMessage;
+
+ /** True when the error is a transport-level connection failure (connect timeout, refused, DNS) */
+ bool IsConnectionError() const
+ {
+ switch (ErrorCode)
+ {
+ case HttpClientErrorCode::kConnectionFailure:
+ case HttpClientErrorCode::kOperationTimedOut:
+ case HttpClientErrorCode::kHostResolutionFailure:
+ case HttpClientErrorCode::kProxyResolutionFailure:
+ return true;
+ default:
+ return false;
+ }
+ }
};
struct KeyValueMap
@@ -171,13 +250,29 @@ public:
KeyValueMap Header;
// The number of bytes sent as part of the request
- int64_t UploadedBytes;
+ int64_t UploadedBytes = 0;
// The number of bytes received as part of the response
- int64_t DownloadedBytes;
+ int64_t DownloadedBytes = 0;
// The elapsed time in seconds for the request to execute
- double ElapsedSeconds;
+ double ElapsedSeconds = 0.0;
+
+ struct MultipartBoundary
+ {
+ uint64_t OffsetInPayload = 0;
+ uint64_t RangeOffset = 0;
+ uint64_t RangeLength = 0;
+ HttpContentType ContentType;
+ };
+
+ // Ranges will map out all received ranges, both single and multi-range responses
+ // If no range was requested Ranges will be empty
+ std::vector<MultipartBoundary> Ranges;
+
+ // Map the absolute OffsetAndLengthPairs into ResponsePayload from the ranges received (Ranges).
+ // If the response was not a partial response, an empty vector will be returned
+ std::vector<std::pair<uint64_t, uint64_t>> GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const;
// This contains any errors from the HTTP stack. It won't contain information on
// why the server responded with a non-success HTTP status, that may be gleaned
@@ -226,7 +321,10 @@ public:
const IoBuffer& Payload,
ZenContentType ContentType,
const KeyValueMap& AdditionalHeader = {});
- [[nodiscard]] Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {});
+ [[nodiscard]] Response Post(std::string_view Url,
+ CbObject Payload,
+ const KeyValueMap& AdditionalHeader = {},
+ const std::filesystem::path& TempFolderPath = {});
[[nodiscard]] Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {});
[[nodiscard]] Response Post(std::string_view Url,
const CompositeBuffer& Payload,
@@ -260,6 +358,16 @@ private:
const HttpClientSettings m_ConnectionSettings;
};
-void httpclient_forcelink(); // internal
+struct LatencyTestResult
+{
+ bool Success = false;
+ std::string FailureReason;
+ double LatencySeconds = -1.0;
+};
+
+LatencyTestResult MeasureLatency(HttpClient& Client, std::string_view Url);
+
+void httpclient_forcelink(); // internal
+void httpclient_test_forcelink(); // internal
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h
index bc18549c9..8fca35ac5 100644
--- a/src/zenhttp/include/zenhttp/httpcommon.h
+++ b/src/zenhttp/include/zenhttp/httpcommon.h
@@ -184,6 +184,13 @@ IsHttpSuccessCode(HttpResponseCode HttpCode) noexcept
return IsHttpSuccessCode(int(HttpCode));
}
+[[nodiscard]] inline bool
+IsHttpOk(HttpResponseCode HttpCode) noexcept
+{
+ return HttpCode == HttpResponseCode::OK || HttpCode == HttpResponseCode::Created || HttpCode == HttpResponseCode::Accepted ||
+ HttpCode == HttpResponseCode::NoContent;
+}
+
std::string_view ToString(HttpResponseCode HttpCode);
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index 7887beacd..42e5b1628 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -13,6 +13,9 @@
#include <zencore/uid.h>
#include <zenhttp/httpcommon.h>
+#include <zentelemetry/stats.h>
+
+#include <filesystem>
#include <functional>
#include <gsl/gsl-lite.hpp>
#include <list>
@@ -103,6 +106,7 @@ public:
virtual bool IsLocalMachineRequest() const = 0;
virtual std::string_view GetAuthorizationHeader() const = 0;
+ virtual std::string_view GetRemoteAddress() const { return {}; }
/** Respond with payload
@@ -202,12 +206,34 @@ private:
int m_UriPrefixLength = 0;
};
+struct IHttpStatsProvider
+{
+ /** Handle an HTTP stats request, writing the response directly.
+ * Implementations may inspect query parameters on the request
+ * to include optional detailed breakdowns.
+ */
+ virtual void HandleStatsRequest(HttpServerRequest& Request) = 0;
+
+ /** Return the provider's current stats as a CbObject snapshot.
+ * Used by the WebSocket push thread to broadcast live updates
+ * without requiring an HttpServerRequest. Providers that do
+ * not override this will be skipped in WebSocket broadcasts.
+ */
+ virtual CbObject CollectStats() { return {}; }
+};
+
+struct IHttpStatsService
+{
+ virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
+ virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
+};
+
/** HTTP server
*
* Implements the main event loop to service HTTP requests, and handles routing
* requests to the appropriate handler as registered via RegisterService
*/
-class HttpServer : public RefCounted
+class HttpServer : public RefCounted, public IHttpStatsProvider
{
public:
void RegisterService(HttpService& Service);
@@ -219,8 +245,65 @@ public:
void RequestExit();
void Close();
+ /** Returns a canonical http:// URI for the given service, using the external
+ * IP and the port the server is actually listening on. Only valid
+ * after Initialize() has returned successfully.
+ */
+ std::string GetServiceUri(const HttpService* Service) const;
+
+ /** Returns the external host string (IP or hostname) determined during Initialize().
+ * Only valid after Initialize() has returned successfully.
+ */
+ std::string_view GetExternalHost() const { return m_ExternalHost; }
+
+ /** Returns the effective HTTPS port, or 0 if HTTPS is not enabled. Only valid after Initialize(). */
+ int GetEffectiveHttpsPort() const { return m_EffectiveHttpsPort; }
+
+ /** Returns total bytes received and sent across all connections since server start. */
+ virtual uint64_t GetTotalBytesReceived() const { return 0; }
+ virtual uint64_t GetTotalBytesSent() const { return 0; }
+
+ /** Mark that a request has been handled. Called by server implementations. */
+ void MarkRequest() { m_RequestMeter.Mark(); }
+
+ /** Set a default redirect path for root requests */
+ void SetDefaultRedirect(std::string_view Path) { m_DefaultRedirect = Path; }
+
+ std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; }
+
+ /** Track active WebSocket connections — called by server implementations on upgrade/close. */
+ void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); }
+ void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); }
+ uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); }
+
+ /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. */
+ void OnWebSocketFrameReceived(uint64_t Bytes)
+ {
+ m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed);
+ m_WsBytesReceived.fetch_add(Bytes, std::memory_order_relaxed);
+ }
+ void OnWebSocketFrameSent(uint64_t Bytes)
+ {
+ m_WsFramesSent.fetch_add(1, std::memory_order_relaxed);
+ m_WsBytesSent.fetch_add(Bytes, std::memory_order_relaxed);
+ }
+
+ // IHttpStatsProvider
+ virtual CbObject CollectStats() override;
+ virtual void HandleStatsRequest(HttpServerRequest& Request) override;
+
private:
std::vector<HttpService*> m_KnownServices;
+ int m_EffectivePort = 0;
+ int m_EffectiveHttpsPort = 0;
+ std::string m_ExternalHost;
+ metrics::Meter m_RequestMeter;
+ std::string m_DefaultRedirect;
+ std::atomic<uint64_t> m_ActiveWebSocketConnections{0};
+ std::atomic<uint64_t> m_WsFramesReceived{0};
+ std::atomic<uint64_t> m_WsFramesSent{0};
+ std::atomic<uint64_t> m_WsBytesReceived{0};
+ std::atomic<uint64_t> m_WsBytesSent{0};
virtual void OnRegisterService(HttpService& Service) = 0;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) = 0;
@@ -228,6 +311,10 @@ private:
virtual void OnRun(bool IsInteractiveSession) = 0;
virtual void OnRequestExit() = 0;
virtual void OnClose() = 0;
+
+protected:
+ void SetEffectiveHttpsPort(int Port) { m_EffectiveHttpsPort = Port; }
+ virtual std::string OnGetExternalHost() const;
};
struct HttpServerPluginConfig
@@ -243,6 +330,11 @@ struct HttpServerConfig
std::vector<HttpServerPluginConfig> PluginConfigs;
bool ForceLoopback = false;
unsigned int ThreadCount = 0;
+ std::filesystem::path UnixSocketPath; // Unix domain socket path (empty = disabled)
+ bool NoNetwork = false; // Disable TCP/HTTPS listeners; only accept connections via UnixSocketPath
+ int HttpsPort = 0; // HTTPS listen port (0 = disabled, ASIO backend)
+ std::string CertFile; // PEM certificate chain file path
+ std::string KeyFile; // PEM private key file path
struct
{
@@ -250,6 +342,10 @@ struct HttpServerConfig
bool IsAsyncResponseEnabled = true;
bool IsRequestLoggingEnabled = false;
bool UseExplicitIoThreadPool = false;
+ int HttpsPort = 0; // 0 = HTTPS disabled
+ std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding
+ std::string CertStoreName = "MY"; // Windows certificate store name
+ bool HttpsOnly = false; // When true, disable HTTP listener
} HttpSys;
};
@@ -420,7 +516,7 @@ public:
~HttpRpcHandler();
HttpRpcHandler(const HttpRpcHandler&) = delete;
- HttpRpcHandler operator=(const HttpRpcHandler&) = delete;
+ HttpRpcHandler& operator=(const HttpRpcHandler&) = delete;
void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction);
@@ -436,17 +532,7 @@ private:
bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef);
-struct IHttpStatsProvider
-{
- virtual void HandleStatsRequest(HttpServerRequest& Request) = 0;
-};
-
-struct IHttpStatsService
-{
- virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
- virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0;
-};
-
-void http_forcelink(); // internal
+void http_forcelink(); // internal
+void websocket_forcelink(); // internal
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h
index e6fea6765..460315faf 100644
--- a/src/zenhttp/include/zenhttp/httpstats.h
+++ b/src/zenhttp/include/zenhttp/httpstats.h
@@ -3,23 +3,50 @@
#pragma once
#include <zencore/logging.h>
+#include <zencore/thread.h>
#include <zenhttp/httpserver.h>
+#include <zenhttp/websocket.h>
+#include <atomic>
#include <map>
+#include <memory>
+#include <thread>
+#include <vector>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio/io_context.hpp>
+#include <asio/steady_timer.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
-class HttpStatsService : public HttpService, public IHttpStatsService
+class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler
{
public:
- HttpStatsService();
+ /// Construct without an io_context — optionally uses a dedicated push thread
+ /// for WebSocket stats broadcasting.
+ explicit HttpStatsService(bool EnableWebSockets = false);
+
+ /// Construct with an external io_context — uses an asio timer instead
+ /// of a dedicated thread for WebSocket stats broadcasting.
+ /// The caller must ensure the io_context outlives this service and that
+ /// its run loop is active.
+ HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets = true);
+
~HttpStatsService();
+ void Shutdown();
+
virtual const char* BaseUri() const override;
virtual void HandleRequest(HttpServerRequest& Request) override;
virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override;
virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override;
+ // IWebSocketHandler
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
+
private:
LoggerRef m_Log;
HttpRequestRouter m_Router;
@@ -28,6 +55,22 @@ private:
RwLock m_Lock;
std::map<std::string, IHttpStatsProvider*> m_Providers;
+
+ // WebSocket push
+ RwLock m_WsConnectionsLock;
+ std::vector<Ref<WebSocketConnection>> m_WsConnections;
+ std::atomic<bool> m_PushEnabled{false};
+
+ void BroadcastStats();
+
+ // Thread-based push (when no io_context is provided)
+ std::thread m_PushThread;
+ Event m_PushEvent;
+ void PushThreadFunction();
+
+ // Timer-based push (when an io_context is provided)
+ std::unique_ptr<asio::steady_timer> m_PushTimer;
+ void EnqueuePushTimer();
};
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h
new file mode 100644
index 000000000..2ca9b7ab1
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpwsclient.h
@@ -0,0 +1,83 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenhttp.h"
+
+#include <zenhttp/httpclient.h>
+#include <zenhttp/websocket.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio/io_context.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <chrono>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <span>
+#include <string>
+#include <string_view>
+
+namespace zen {
+
+/**
+ * Callback interface for WebSocket client events
+ *
+ * Separate from the server-side IWebSocketHandler because the caller
+ * already owns the HttpWsClient — no Ref<WebSocketConnection> needed.
+ */
+class IWsClientHandler
+{
+public:
+ virtual ~IWsClientHandler() = default;
+
+ virtual void OnWsOpen() = 0;
+ virtual void OnWsMessage(const WebSocketMessage& Msg) = 0;
+ virtual void OnWsClose(uint16_t Code, std::string_view Reason) = 0;
+};
+
+struct HttpWsClientSettings
+{
+ std::string LogCategory = "wsclient";
+ std::chrono::milliseconds ConnectTimeout{5000};
+ std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider;
+
+ /// Unix domain socket path. When non-empty, connects via this socket
+ /// instead of TCP. The URL host is still used for the Host header.
+ std::filesystem::path UnixSocketPath;
+};
+
+/**
+ * WebSocket client over TCP (ws:// scheme)
+ *
+ * Uses ASIO for async I/O. Two construction modes:
+ * - Internal io_context + background thread (standalone use)
+ * - External io_context (shared event loop, no internal thread)
+ *
+ * Thread-safe for SendText/SendBinary/Close.
+ */
+class HttpWsClient
+{
+public:
+ HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings = {});
+ HttpWsClient(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings = {});
+
+ ~HttpWsClient();
+
+ HttpWsClient(const HttpWsClient&) = delete;
+ HttpWsClient& operator=(const HttpWsClient&) = delete;
+
+ void Connect();
+ void SendText(std::string_view Text);
+ void SendBinary(std::span<const uint8_t> Data);
+ void Close(uint16_t Code = 1000, std::string_view Reason = {});
+ bool IsOpen() const;
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h
index c90b840da..1a5068580 100644
--- a/src/zenhttp/include/zenhttp/packageformat.h
+++ b/src/zenhttp/include/zenhttp/packageformat.h
@@ -68,7 +68,7 @@ struct CbAttachmentEntry
struct CbAttachmentReferenceHeader
{
uint64_t PayloadByteOffset = 0;
- uint64_t PayloadByteSize = ~0u;
+ uint64_t PayloadByteSize = ~uint64_t(0);
uint16_t AbsolutePathLength = 0;
// This header will be followed by UTF8 encoded absolute path to backing file
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
new file mode 100644
index 000000000..bc3293282
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/websocket.h
@@ -0,0 +1,65 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenbase/refcount.h>
+#include <zencore/iobuffer.h>
+
+#include <cstdint>
+#include <span>
+#include <string_view>
+
+namespace zen {
+
+enum class WebSocketOpcode : uint8_t
+{
+ kText = 0x1,
+ kBinary = 0x2,
+ kClose = 0x8,
+ kPing = 0x9,
+ kPong = 0xA
+};
+
+struct WebSocketMessage
+{
+ WebSocketOpcode Opcode = WebSocketOpcode::kText;
+ IoBuffer Payload;
+ uint16_t CloseCode = 0;
+};
+
+/**
+ * Represents an active WebSocket connection
+ *
+ * Derived classes implement the actual transport (e.g. ASIO sockets).
+ * Instances are reference-counted so that both the service layer and
+ * the async read/write loop can share ownership.
+ */
+class WebSocketConnection : public RefCounted
+{
+public:
+ virtual ~WebSocketConnection() = default;
+
+ virtual void SendText(std::string_view Text) = 0;
+ virtual void SendBinary(std::span<const uint8_t> Data) = 0;
+ virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0;
+ virtual bool IsOpen() const = 0;
+};
+
+/**
+ * Interface for services that accept WebSocket upgrades
+ *
+ * An HttpService may additionally implement this interface to indicate
+ * it supports WebSocket connections. The HTTP server checks for this
+ * via dynamic_cast when it sees an Upgrade: websocket request.
+ */
+class IWebSocketHandler
+{
+public:
+ virtual ~IWebSocketHandler() = default;
+
+ virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0;
+ virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0;
+ virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp
index b097a0d3f..3877215a8 100644
--- a/src/zenhttp/monitoring/httpstats.cpp
+++ b/src/zenhttp/monitoring/httpstats.cpp
@@ -3,15 +3,57 @@
#include "zenhttp/httpstats.h"
#include <zencore/compactbinarybuilder.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/trace.h>
namespace zen {
-HttpStatsService::HttpStatsService() : m_Log(logging::Get("stats"))
+HttpStatsService::HttpStatsService(bool EnableWebSockets) : m_Log(logging::Get("stats"))
{
+ if (EnableWebSockets)
+ {
+ m_PushEnabled.store(true);
+ m_PushThread = std::thread([this] { PushThreadFunction(); });
+ }
+}
+
+HttpStatsService::HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets) : m_Log(logging::Get("stats"))
+{
+ if (EnableWebSockets)
+ {
+ m_PushEnabled.store(true);
+ m_PushTimer = std::make_unique<asio::steady_timer>(IoContext);
+ EnqueuePushTimer();
+ }
}
HttpStatsService::~HttpStatsService()
{
+ Shutdown();
+}
+
+void
+HttpStatsService::Shutdown()
+{
+ if (!m_PushEnabled.exchange(false))
+ {
+ return;
+ }
+
+ if (m_PushTimer)
+ {
+ m_PushTimer->cancel();
+ m_PushTimer.reset();
+ }
+
+ if (m_PushThread.joinable())
+ {
+ m_PushEvent.Set();
+ m_PushThread.join();
+ }
+
+ m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); });
}
const char*
@@ -39,6 +81,7 @@ HttpStatsService::UnregisterHandler(std::string_view Id, IHttpStatsProvider& Pro
void
HttpStatsService::HandleRequest(HttpServerRequest& Request)
{
+ ZEN_TRACE_CPU("HttpStatsService::HandleRequest");
using namespace std::literals;
std::string_view Key = Request.RelativeUri();
@@ -89,4 +132,154 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request)
}
}
+//////////////////////////////////////////////////////////////////////////
+//
+// IWebSocketHandler
+//
+
+void
+HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+{
+ ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen");
+ ZEN_INFO("Stats WebSocket client connected");
+
+ m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); });
+
+ // Send initial state immediately
+ if (m_PushThread.joinable())
+ {
+ m_PushEvent.Set();
+ }
+}
+
+void
+HttpStatsService::OnWebSocketMessage(WebSocketConnection& /*Conn*/, const WebSocketMessage& /*Msg*/)
+{
+ // No client-to-server messages expected
+}
+
+void
+HttpStatsService::OnWebSocketClose(WebSocketConnection& Conn, [[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason)
+{
+ ZEN_TRACE_CPU("HttpStatsService::OnWebSocketClose");
+ ZEN_INFO("Stats WebSocket client disconnected (code {})", Code);
+
+ m_WsConnectionsLock.WithExclusiveLock([&] {
+ auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) {
+ return C.Get() == &Conn;
+ });
+ m_WsConnections.erase(It, m_WsConnections.end());
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Stats broadcast
+//
+
+void
+HttpStatsService::BroadcastStats()
+{
+ ZEN_TRACE_CPU("HttpStatsService::BroadcastStats");
+ std::vector<Ref<WebSocketConnection>> Connections;
+ m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; });
+
+ if (Connections.empty())
+ {
+ return;
+ }
+
+ // Collect stats from all providers
+ ExtendableStringBuilder<4096> JsonBuilder;
+ JsonBuilder.Append("{");
+
+ bool First = true;
+ {
+ RwLock::SharedLockScope _(m_Lock);
+ for (auto& [Id, Provider] : m_Providers)
+ {
+ CbObject Stats = Provider->CollectStats();
+ if (!Stats)
+ {
+ continue;
+ }
+
+ if (!First)
+ {
+ JsonBuilder.Append(",");
+ }
+ First = false;
+
+ // Emit as "provider_id": { ... }
+ JsonBuilder.Append("\"");
+ JsonBuilder.Append(Id);
+ JsonBuilder.Append("\":");
+
+ ExtendableStringBuilder<2048> StatsJson;
+ Stats.ToJson(StatsJson);
+ JsonBuilder.Append(StatsJson.ToView());
+ }
+ }
+
+ JsonBuilder.Append("}");
+
+ std::string_view Json = JsonBuilder.ToView();
+ for (auto& Conn : Connections)
+ {
+ if (Conn->IsOpen())
+ {
+ Conn->SendText(Json);
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Thread-based push (fallback when no io_context)
+//
+
+void
+HttpStatsService::PushThreadFunction()
+{
+ SetCurrentThreadName("stats_ws_push");
+
+ while (m_PushEnabled.load())
+ {
+ m_PushEvent.Wait(1000);
+ m_PushEvent.Reset();
+
+ if (!m_PushEnabled.load())
+ {
+ break;
+ }
+
+ BroadcastStats();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Timer-based push (when io_context is provided)
+//
+
+void
+HttpStatsService::EnqueuePushTimer()
+{
+ if (!m_PushTimer)
+ {
+ return;
+ }
+
+ m_PushTimer->expires_after(std::chrono::seconds(1));
+ m_PushTimer->async_wait([this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ return;
+ }
+
+ BroadcastStats();
+ EnqueuePushTimer();
+ });
+}
+
} // namespace zen
diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp
index 708238224..9c62c1f2d 100644
--- a/src/zenhttp/packageformat.cpp
+++ b/src/zenhttp/packageformat.cpp
@@ -575,13 +575,21 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
}
else if (AttachmentSize > 0)
{
- // Make a copy of the buffer so the attachments don't reference the entire payload
- IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize);
- ZEN_ASSERT(AttachmentBufferCopy);
- ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize);
- AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView());
+ IoBufferFileReference TestIfFileRef;
+ if (AttachmentBuffer.GetFileReference(TestIfFileRef))
+ {
+ Attachments.emplace_back(CbAttachment(SharedBuffer{std::move(AttachmentBuffer)}, Entry.AttachmentHash));
+ }
+ else
+ {
+ // Make a copy of the buffer so the attachments don't reference the entire payload
+ IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize);
+ ZEN_ASSERT(AttachmentBufferCopy);
+ ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize);
+ AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView());
- Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy});
+ Attachments.emplace_back(CbAttachment(SharedBuffer{AttachmentBufferCopy}, Entry.AttachmentHash));
+ }
}
else
{
@@ -805,6 +813,8 @@ CbPackageReader::Finalize()
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("http.packageformat");
+
TEST_CASE("CbPackage.Serialization")
{
// Make a test package
@@ -926,6 +936,8 @@ TEST_CASE("CbPackage.LocalRef")
Reader.Finalize();
}
+TEST_SUITE_END();
+
void
forcelink_packageformat()
{
diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp
index a8fb9c3f5..0e3a743c3 100644
--- a/src/zenhttp/security/passwordsecurity.cpp
+++ b/src/zenhttp/security/passwordsecurity.cpp
@@ -76,6 +76,8 @@ PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUr
#if ZEN_WITH_TESTS
+TEST_SUITE_BEGIN("http.passwordsecurity");
+
TEST_CASE("passwordsecurity.allowanything")
{
PasswordSecurity Anything({});
@@ -162,6 +164,9 @@ TEST_CASE("passwordsecurity.conflictingunprotecteduris")
"uri #1 ('/free/access')"));
}
}
+
+TEST_SUITE_END();
+
void
passwordsecurity_forcelink()
{
diff --git a/src/zenhttp/servers/asio_socket_traits.h b/src/zenhttp/servers/asio_socket_traits.h
new file mode 100644
index 000000000..25aeaa24e
--- /dev/null
+++ b/src/zenhttp/servers/asio_socket_traits.h
@@ -0,0 +1,54 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::asio_http {
+
+/**
+ * Traits for abstracting socket shutdown/close across plain TCP, Unix domain, and SSL sockets.
+ * SSL sockets need lowest_layer() access and have different shutdown semantics.
+ */
+template<typename SocketType>
+struct SocketTraits
+{
+ /// SSL sockets cannot use zero-copy file send (TransmitFile/sendfile) because
+ /// those bypass the encryption layer. This flag lets templated code fall back
+ /// to reading-into-memory for SSL connections.
+ static constexpr bool IsSslSocket = false;
+
+ static void ShutdownReceive(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_receive, Ec); }
+
+ static void ShutdownBoth(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_both, Ec); }
+
+ static void Close(SocketType& S, std::error_code& Ec) { S.close(Ec); }
+};
+
+#if ZEN_USE_OPENSSL
+using SslSocket = asio::ssl::stream<asio::ip::tcp::socket>;
+
+template<>
+struct SocketTraits<SslSocket>
+{
+ static constexpr bool IsSslSocket = true;
+
+ static void ShutdownReceive(SslSocket& S, std::error_code& Ec) { S.lowest_layer().shutdown(asio::socket_base::shutdown_receive, Ec); }
+
+ static void ShutdownBoth(SslSocket& S, std::error_code& Ec)
+ {
+ // Best-effort SSL close_notify, then TCP shutdown
+ S.shutdown(Ec);
+ S.lowest_layer().shutdown(asio::socket_base::shutdown_both, Ec);
+ }
+
+ static void Close(SslSocket& S, std::error_code& Ec) { S.lowest_layer().close(Ec); }
+};
+#endif
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 1c0ebef90..643f33618 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -1,18 +1,22 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include "httpasio.h"
+#include "asio_socket_traits.h"
#include "httptracer.h"
#include <zencore/except.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/memory/llm.h>
+#include <zencore/system.h>
#include <zencore/thread.h>
#include <zencore/trace.h>
#include <zencore/windows.h>
#include <zenhttp/httpserver.h>
#include "httpparser.h"
+#include "wsasio.h"
+#include "wsframecodec.h"
#include <EASTL/fixed_vector.h>
@@ -32,6 +36,12 @@ ZEN_THIRD_PARTY_INCLUDES_START
#endif
#include <asio.hpp>
#include <asio/stream_file.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
ZEN_THIRD_PARTY_INCLUDES_END
#define ASIO_VERBOSE_TRACE 0
@@ -89,10 +99,10 @@ IsIPv6AvailableSysctl(void)
char buf[16];
if (fgets(buf, sizeof(buf), f))
{
- fclose(f);
// 0 means IPv6 enabled, 1 means disabled
val = atoi(buf);
}
+ fclose(f);
}
return val == 0;
@@ -141,13 +151,23 @@ using namespace std::literals;
struct HttpAcceptor;
struct HttpResponse;
-struct HttpServerConnection;
+template<typename SocketType>
+struct HttpServerConnectionT;
+using HttpServerConnection = HttpServerConnectionT<asio::ip::tcp::socket>;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+struct UnixAcceptor;
+using UnixServerConnection = HttpServerConnectionT<asio::local::stream_protocol::socket>;
+#endif
+#if ZEN_USE_OPENSSL
+struct HttpsAcceptor;
+using HttpsSslServerConnection = HttpServerConnectionT<SslSocket>;
+#endif
inline LoggerRef
InitLogger()
{
LoggerRef Logger = logging::Get("asio");
- // Logger.SetLogLevel(logging::level::Trace);
+ // Logger.SetLogLevel(logging::Trace);
return Logger;
}
@@ -173,9 +193,9 @@ Log()
#endif
#if ZEN_USE_TRANSMITFILE
-template<typename Handler>
+template<typename Handler, typename SocketType>
void
-TransmitFileAsync(asio::ip::tcp::socket& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb)
+TransmitFileAsync(SocketType& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb)
{
# if ZEN_BUILD_DEBUG
const uint64_t FileSize = FileSizeFromHandle(FileHandle);
@@ -506,11 +526,22 @@ public:
HttpService* RouteRequest(std::string_view Url);
IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
- asio::io_service m_IoService;
- asio::io_service::work m_Work{m_IoService};
- std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
- std::vector<std::thread> m_ThreadPool;
- std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+ bool IsLoopbackOnly() const;
+
+ int GetEffectiveHttpsPort() const;
+
+ asio::io_context m_IoService;
+ asio::executor_work_guard<asio::io_context::executor_type> m_Work{m_IoService.get_executor()};
+ std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ std::unique_ptr<asio_http::UnixAcceptor> m_UnixAcceptor;
+#endif
+#if ZEN_USE_OPENSSL
+ std::unique_ptr<asio::ssl::context> m_SslContext;
+ std::unique_ptr<asio_http::HttpsAcceptor> m_HttpsAcceptor;
+#endif
+ std::vector<std::thread> m_ThreadPool;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
LoggerRef m_RequestLog;
HttpServerTracer m_RequestTracer;
@@ -523,6 +554,11 @@ public:
RwLock m_Lock;
std::vector<ServiceEntry> m_UriHandlers;
+
+ std::atomic<uint64_t> m_TotalBytesReceived{0};
+ std::atomic<uint64_t> m_TotalBytesSent{0};
+
+ HttpServer* m_HttpServer = nullptr;
};
/**
@@ -536,7 +572,8 @@ public:
HttpService& Service,
IoBuffer PayloadBuffer,
uint32_t RequestNumber,
- bool IsLocalMachineRequest);
+ bool IsLocalMachineRequest,
+ std::string RemoteAddress);
~HttpAsioServerRequest();
virtual Oid ParseSessionId() const override;
@@ -544,6 +581,7 @@ public:
virtual bool IsLocalMachineRequest() const override;
virtual std::string_view GetAuthorizationHeader() const override;
+ virtual std::string_view GetRemoteAddress() const override;
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
@@ -561,6 +599,8 @@ public:
uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers
IoBuffer m_PayloadBuffer;
bool m_IsLocalMachineRequest;
+ bool m_AllowZeroCopyFileSend = true;
+ std::string m_RemoteAddress;
std::unique_ptr<HttpResponse> m_Response;
};
@@ -582,6 +622,8 @@ public:
~HttpResponse() = default;
+ void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; }
+
/**
* Initialize the response for sending a payload made up of multiple blobs
*
@@ -623,7 +665,7 @@ public:
bool ChunkHandled = false;
#if ZEN_USE_TRANSMITFILE || ZEN_USE_ASYNC_SENDFILE
- if (OwnedBuffer.IsWholeFile())
+ if (m_AllowZeroCopyFileSend && OwnedBuffer.IsWholeFile())
{
if (IoBufferFileReference FileRef; OwnedBuffer.GetFileReference(/* out */ FileRef))
{
@@ -738,7 +780,8 @@ public:
return m_Headers;
}
- void SendResponse(asio::ip::tcp::socket& TcpSocket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token)
+ template<typename SocketType>
+ void SendResponse(SocketType& Socket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token)
{
ZEN_ASSERT(m_State == State::kInitialized);
@@ -748,10 +791,11 @@ public:
m_SendCb = std::move(Token);
m_State = State::kSending;
- SendNextChunk(TcpSocket);
+ SendNextChunk(Socket);
}
- void SendNextChunk(asio::ip::tcp::socket& TcpSocket)
+ template<typename SocketType>
+ void SendNextChunk(SocketType& Socket)
{
ZEN_ASSERT(m_State == State::kSending);
@@ -768,12 +812,12 @@ public:
auto CompletionToken = [Self = this, Token = std::move(m_SendCb), TotalBytes = m_TotalBytesSent] { Token({}, TotalBytes); };
- asio::defer(TcpSocket.get_executor(), std::move(CompletionToken));
+ asio::defer(Socket.get_executor(), std::move(CompletionToken));
return;
}
- auto OnCompletion = [this, &TcpSocket](const asio::error_code& Ec, std::size_t ByteCount) {
+ auto OnCompletion = [this, &Socket](const asio::error_code& Ec, std::size_t ByteCount) {
ZEN_ASSERT(m_State == State::kSending);
m_TotalBytesSent += ByteCount;
@@ -784,7 +828,7 @@ public:
}
else
{
- SendNextChunk(TcpSocket);
+ SendNextChunk(Socket);
}
};
@@ -798,25 +842,21 @@ public:
Io.Ref.FileRef.FileChunkSize);
#if ZEN_USE_TRANSMITFILE
- TransmitFileAsync(TcpSocket,
+ TransmitFileAsync(Socket,
Io.Ref.FileRef.FileHandle,
Io.Ref.FileRef.FileChunkOffset,
gsl::narrow_cast<uint32_t>(Io.Ref.FileRef.FileChunkSize),
OnCompletion);
+ return;
#elif ZEN_USE_ASYNC_SENDFILE
- SendFileAsync(TcpSocket,
+ SendFileAsync(Socket,
Io.Ref.FileRef.FileHandle,
Io.Ref.FileRef.FileChunkOffset,
Io.Ref.FileRef.FileChunkSize,
64 * 1024,
OnCompletion);
-#else
- // This should never occur unless we compile with one
- // of the options above
- ZEN_WARN("invalid file reference in response");
-#endif
-
return;
+#endif
}
// Send as many consecutive non-file references as possible in one asio operation
@@ -837,7 +877,7 @@ public:
++m_IoVecCursor;
}
- asio::async_write(TcpSocket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion);
+ asio::async_write(Socket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion);
}
private:
@@ -850,12 +890,13 @@ private:
kFailed
};
- uint32_t m_RequestNumber = 0;
- uint16_t m_ResponseCode = 0;
- bool m_IsKeepAlive = true;
- State m_State = State::kUninitialized;
- HttpContentType m_ContentType = HttpContentType::kBinary;
- uint64_t m_ContentLength = 0;
+ uint32_t m_RequestNumber = 0;
+ uint16_t m_ResponseCode = 0;
+ bool m_IsKeepAlive = true;
+ bool m_AllowZeroCopyFileSend = true;
+ State m_State = State::kUninitialized;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ uint64_t m_ContentLength = 0;
eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive
ExtendableStringBuilder<160> m_Headers;
@@ -882,12 +923,13 @@ private:
//////////////////////////////////////////////////////////////////////////
-struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnection>
+template<typename SocketType>
+struct HttpServerConnectionT : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnectionT<SocketType>>
{
- HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket);
- ~HttpServerConnection();
+ HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket);
+ ~HttpServerConnectionT();
- std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); }
+ std::shared_ptr<HttpServerConnectionT> AsSharedPtr() { return this->shared_from_this(); }
// HttpConnectionBase implementation
@@ -938,6 +980,7 @@ private:
void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, uint32_t RequestNumber, HttpResponse* ResponseToPop);
void CloseConnection();
+ void SendInlineResponse(uint32_t RequestNumber, std::string_view StatusLine, std::string_view Headers = {}, std::string_view Body = {});
HttpAsioServerImpl& m_Server;
asio::streambuf m_RequestBuffer;
@@ -948,12 +991,13 @@ private:
RwLock m_ActiveResponsesLock;
std::deque<std::unique_ptr<HttpResponse>> m_ActiveResponses;
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<SocketType> m_Socket;
};
std::atomic<uint32_t> g_ConnectionIdCounter{0};
-HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket)
+template<typename SocketType>
+HttpServerConnectionT<SocketType>::HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket)
: m_Server(Server)
, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1))
, m_Socket(std::move(Socket))
@@ -961,21 +1005,24 @@ HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::uniq
ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId);
}
-HttpServerConnection::~HttpServerConnection()
+template<typename SocketType>
+HttpServerConnectionT<SocketType>::~HttpServerConnectionT()
{
RwLock::ExclusiveLockScope _(m_ActiveResponsesLock);
ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId);
}
+template<typename SocketType>
void
-HttpServerConnection::HandleNewRequest()
+HttpServerConnectionT<SocketType>::HandleNewRequest()
{
EnqueueRead();
}
+template<typename SocketType>
void
-HttpServerConnection::TerminateConnection()
+HttpServerConnectionT<SocketType>::TerminateConnection()
{
if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated)
{
@@ -987,12 +1034,13 @@ HttpServerConnection::TerminateConnection()
// Terminating, we don't care about any errors when closing socket
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_both, Ec);
- m_Socket->close(Ec);
+ SocketTraits<SocketType>::ShutdownBoth(*m_Socket, Ec);
+ SocketTraits<SocketType>::Close(*m_Socket, Ec);
}
+template<typename SocketType>
void
-HttpServerConnection::EnqueueRead()
+HttpServerConnectionT<SocketType>::EnqueueRead()
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1013,8 +1061,9 @@ HttpServerConnection::EnqueueRead()
[Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); });
}
+template<typename SocketType>
void
-HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+HttpServerConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1038,6 +1087,8 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]
}
}
+ m_Server.m_TotalBytesReceived.fetch_add(ByteCount, std::memory_order_relaxed);
+
ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}",
m_ConnectionId,
m_RequestCounter.load(std::memory_order_relaxed),
@@ -1070,11 +1121,12 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]
}
}
+template<typename SocketType>
void
-HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
- [[maybe_unused]] std::size_t ByteCount,
- [[maybe_unused]] uint32_t RequestNumber,
- HttpResponse* ResponseToPop)
+HttpServerConnectionT<SocketType>::OnResponseDataSent(const asio::error_code& Ec,
+ [[maybe_unused]] std::size_t ByteCount,
+ [[maybe_unused]] uint32_t RequestNumber,
+ HttpResponse* ResponseToPop)
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1091,6 +1143,8 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
return;
}
+ m_Server.m_TotalBytesSent.fetch_add(ByteCount, std::memory_order_relaxed);
+
ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}",
m_ConnectionId,
RequestNumber,
@@ -1126,8 +1180,9 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
}
}
+template<typename SocketType>
void
-HttpServerConnection::CloseConnection()
+HttpServerConnectionT<SocketType>::CloseConnection()
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1139,29 +1194,113 @@ HttpServerConnection::CloseConnection()
m_RequestState = RequestState::kDone;
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
+ SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec);
if (Ec)
{
ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message());
}
- m_Socket->close(Ec);
+ SocketTraits<SocketType>::Close(*m_Socket, Ec);
if (Ec)
{
ZEN_WARN("socket close ERROR, reason '{}'", Ec.message());
}
}
+template<typename SocketType>
+void
+HttpServerConnectionT<SocketType>::SendInlineResponse(uint32_t RequestNumber,
+ std::string_view StatusLine,
+ std::string_view Headers,
+ std::string_view Body)
+{
+ ExtendableStringBuilder<256> ResponseBuilder;
+ ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n";
+ if (!Headers.empty())
+ {
+ ResponseBuilder << Headers;
+ }
+ if (!m_RequestData.IsKeepAlive())
+ {
+ ResponseBuilder << "Connection: close\r\n";
+ }
+ ResponseBuilder << "\r\n";
+ if (!Body.empty())
+ {
+ ResponseBuilder << Body;
+ }
+ auto ResponseView = ResponseBuilder.ToView();
+ IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size());
+ auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize());
+ asio::async_write(
+ *m_Socket,
+ Buffer,
+ [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) {
+ Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
+ });
+}
+
+template<typename SocketType>
void
-HttpServerConnection::HandleRequest()
+HttpServerConnectionT<SocketType>::HandleRequest()
{
ZEN_MEMSCOPE(GetHttpasioTag());
+ // WebSocket upgrade detection must happen before the keep-alive check below,
+ // because Upgrade requests have "Connection: Upgrade" which the HTTP parser
+ // treats as non-keep-alive, causing a premature shutdown of the receive side.
+ if (m_RequestData.IsWebSocketUpgrade())
+ {
+ if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url()))
+ {
+ IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service);
+ if (WsHandler && !m_RequestData.SecWebSocketKey().empty())
+ {
+ std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(m_RequestData.SecWebSocketKey());
+
+ auto ResponseStr = std::make_shared<std::string>();
+ ResponseStr->reserve(256);
+ ResponseStr->append(
+ "HTTP/1.1 101 Switching Protocols\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Accept: ");
+ ResponseStr->append(AcceptKey);
+ ResponseStr->append("\r\n\r\n");
+
+ // Send the 101 response on the current socket, then hand the socket off
+ // to a WsAsioConnectionT for the WebSocket protocol.
+ asio::async_write(
+ *m_Socket,
+ asio::buffer(ResponseStr->data(), ResponseStr->size()),
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
+ return;
+ }
+
+ Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened();
+ using WsConnType = WsAsioConnectionT<SocketType>;
+ Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+ });
+
+ m_RequestState = RequestState::kDone;
+ return;
+ }
+ }
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
+
if (!m_RequestData.IsKeepAlive())
{
m_RequestState = RequestState::kWritingFinal;
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
+ SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec);
if (Ec)
{
@@ -1179,16 +1318,45 @@ HttpServerConnection::HandleRequest()
{
ZEN_TRACE_CPU("asio::HandleRequest");
- bool IsLocalConnection = m_Socket->local_endpoint().address() == m_Socket->remote_endpoint().address();
+ m_Server.m_HttpServer->MarkRequest();
+
+ bool IsLocalConnection = true;
+ std::string RemoteAddress;
+
+ if constexpr (std::is_same_v<SocketType, asio::ip::tcp::socket>)
+ {
+ auto RemoteEndpoint = m_Socket->remote_endpoint();
+ IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address();
+ RemoteAddress = RemoteEndpoint.address().to_string();
+ }
+#if ZEN_USE_OPENSSL
+ else if constexpr (std::is_same_v<SocketType, SslSocket>)
+ {
+ auto RemoteEndpoint = m_Socket->lowest_layer().remote_endpoint();
+ IsLocalConnection = m_Socket->lowest_layer().local_endpoint().address() == RemoteEndpoint.address();
+ RemoteAddress = RemoteEndpoint.address().to_string();
+ }
+#endif
+ else
+ {
+ RemoteAddress = "unix";
+ }
+
+ HttpAsioServerRequest Request(m_RequestData,
+ *Service,
+ m_RequestData.Body(),
+ RequestNumber,
+ IsLocalConnection,
+ std::move(RemoteAddress));
- HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber, IsLocalConnection);
+ Request.m_AllowZeroCopyFileSend = !SocketTraits<SocketType>::IsSslSocket;
ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber);
const HttpVerb RequestVerb = Request.RequestVerb();
const std::string_view Uri = Request.RelativeUri();
- if (m_Server.m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server.m_RequestLog.ShouldLog(logging::Trace))
{
ZEN_LOG_TRACE(m_Server.m_RequestLog,
"connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})",
@@ -1310,63 +1478,45 @@ HttpServerConnection::HandleRequest()
}
}
- if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ // If a default redirect is configured and the request is for the root path, send a 302
+ std::string_view DefaultRedirect = m_Server.m_HttpServer->GetDefaultRedirect();
+ if (!DefaultRedirect.empty() && (m_RequestData.Url() == "/" || m_RequestData.Url().empty()))
{
- std::string_view Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "\r\n"sv;
-
- if (!m_RequestData.IsKeepAlive())
- {
- Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Connection: close\r\n"
- "\r\n"sv;
- }
-
- asio::async_write(*m_Socket.get(),
- asio::buffer(Response),
- [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) {
- Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
- });
+ ExtendableStringBuilder<128> Headers;
+ Headers << "Location: " << DefaultRedirect << "\r\nContent-Length: 0\r\n";
+ SendInlineResponse(RequestNumber, "302 Found"sv, Headers.ToView());
+ }
+ else if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ {
+ SendInlineResponse(RequestNumber, "404 NOT FOUND"sv);
}
else
{
- std::string_view Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Content-Length: 23\r\n"
- "Content-Type: text/plain\r\n"
- "\r\n"
- "No suitable route found"sv;
-
- if (!m_RequestData.IsKeepAlive())
- {
- Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Content-Length: 23\r\n"
- "Content-Type: text/plain\r\n"
- "Connection: close\r\n"
- "\r\n"
- "No suitable route found"sv;
- }
-
- asio::async_write(*m_Socket.get(),
- asio::buffer(Response),
- [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) {
- Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
- });
+ SendInlineResponse(RequestNumber,
+ "404 NOT FOUND"sv,
+ "Content-Length: 23\r\nContent-Type: text/plain\r\n"sv,
+ "No suitable route found"sv);
}
}
//////////////////////////////////////////////////////////////////////////
+// Base class for TCP acceptors that handles socket setup, port binding
+// with probing/retry, and dual-stack (IPv6+IPv4 loopback) support.
+// Subclasses only need to implement OnAccept() to handle new connections.
-struct HttpAcceptor
+struct TcpAcceptorBase
{
- HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ TcpAcceptorBase(HttpAsioServerImpl& Server,
+ asio::io_context& IoService,
+ uint16_t BasePort,
+ bool ForceLoopback,
+ bool AllowPortProbing,
+ std::string_view Label)
: m_Server(Server)
, m_IoService(IoService)
, m_Acceptor(m_IoService, asio::ip::tcp::v6())
, m_AlternateProtocolAcceptor(m_IoService, asio::ip::tcp::v4())
+ , m_Label(Label)
{
const bool IsUsingIPv6 = IsIPv6Capable();
if (!IsUsingIPv6)
@@ -1375,93 +1525,66 @@ struct HttpAcceptor
}
#if ZEN_PLATFORM_WINDOWS
- // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms
typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> exclusive_address;
m_Acceptor.set_option(exclusive_address(true));
m_AlternateProtocolAcceptor.set_option(exclusive_address(true));
#else // ZEN_PLATFORM_WINDOWS
- m_Acceptor.set_option(asio::socket_base::reuse_address(false));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(false));
+ // Allow binding to a port in TIME_WAIT so the server can restart immediately
+ // after a previous instance exits. On Linux this does not allow two processes
+ // to actively listen on the same port simultaneously.
+ m_Acceptor.set_option(asio::socket_base::reuse_address(true));
+ m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(true));
#endif // ZEN_PLATFORM_WINDOWS
m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
- m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
m_AlternateProtocolAcceptor.set_option(asio::ip::tcp::no_delay(true));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
- std::string BoundBaseUrl;
if (IsUsingIPv6)
{
- BoundBaseUrl = BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing);
+ BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing);
}
else
{
- ZEN_INFO("NOTE: ipv6 support is disabled, binding to ipv4 only");
-
- BoundBaseUrl = BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing);
+ ZEN_INFO("{}: ipv6 support is disabled, binding to ipv4 only", m_Label);
+ BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing);
}
+ }
- if (!IsValid())
- {
- return;
- }
-
-#if ZEN_PLATFORM_WINDOWS
- // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
- // This must be used by both the client and server side, and is only effective in the absence of
- // Windows Filtering Platform (WFP) callouts which can be installed by security software.
- // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
- SOCKET NativeSocket = m_Acceptor.native_handle();
- int LoopbackOptionValue = 1;
- DWORD OptionNumberOfBytesReturned = 0;
- WSAIoctl(NativeSocket,
- SIO_LOOPBACK_FAST_PATH,
- &LoopbackOptionValue,
- sizeof(LoopbackOptionValue),
- NULL,
- 0,
- &OptionNumberOfBytesReturned,
- 0,
- 0);
-
- if (m_UseAlternateProtocolAcceptor)
- {
- NativeSocket = m_AlternateProtocolAcceptor.native_handle();
- WSAIoctl(NativeSocket,
- SIO_LOOPBACK_FAST_PATH,
- &LoopbackOptionValue,
- sizeof(LoopbackOptionValue),
- NULL,
- 0,
- &OptionNumberOfBytesReturned,
- 0,
- 0);
- }
-#endif
- m_Acceptor.listen();
+ virtual ~TcpAcceptorBase()
+ {
+ m_Acceptor.close();
if (m_UseAlternateProtocolAcceptor)
{
- m_AlternateProtocolAcceptor.listen();
+ m_AlternateProtocolAcceptor.close();
}
-
- ZEN_INFO("Started asio server at '{}", BoundBaseUrl);
}
- ~HttpAcceptor()
+ void Start()
{
- m_Acceptor.close();
+ ZEN_ASSERT(!m_IsStopped);
+ InitAcceptLoop(m_Acceptor);
if (m_UseAlternateProtocolAcceptor)
{
- m_AlternateProtocolAcceptor.close();
+ InitAcceptLoop(m_AlternateProtocolAcceptor);
}
}
+ void StopAccepting() { m_IsStopped = true; }
+
+ uint16_t GetPort() const { return m_Acceptor.local_endpoint().port(); }
+ bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); }
+ bool IsValid() const { return m_IsValid; }
+
+protected:
+ /// Called for each accepted TCP socket. Subclasses create the appropriate connection type.
+ virtual void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) = 0;
+
+ HttpAsioServerImpl& m_Server;
+ asio::io_context& m_IoService;
+
+private:
template<typename AddressType>
- std::string BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ void BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
{
uint16_t EffectivePort = BasePort;
@@ -1488,7 +1611,7 @@ struct HttpAcceptor
if (BindErrorCode == asio::error::access_denied && !BindAddress.is_loopback())
{
- ZEN_INFO("Access denied for public port {}, falling back to loopback", BasePort);
+ ZEN_INFO("{}: Access denied for public port {}, falling back to loopback", m_Label, BasePort);
BindAddress = AddressType::loopback();
@@ -1502,7 +1625,7 @@ struct HttpAcceptor
if (BindErrorCode == asio::error::address_in_use)
{
- ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message());
+ ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message());
Sleep(500);
m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode);
}
@@ -1518,7 +1641,8 @@ struct HttpAcceptor
if (BindErrorCode)
{
- ZEN_INFO("Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')",
+ ZEN_INFO("{}: Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')",
+ m_Label,
BindErrorCode.message());
EffectivePort = 0;
@@ -1534,7 +1658,7 @@ struct HttpAcceptor
{
for (uint32_t Retries = 0; (BindErrorCode == asio::error::address_in_use) && (Retries < 3); Retries++)
{
- ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message());
+ ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message());
Sleep(500);
m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode);
}
@@ -1542,14 +1666,13 @@ struct HttpAcceptor
if (BindErrorCode)
{
- ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message());
-
- return 0;
+ ZEN_WARN("{}: Unable to bind on port {} (bind returned '{}')", m_Label, BasePort, BindErrorCode.message());
+ return;
}
if (EffectivePort != BasePort)
{
- ZEN_WARN("Desired port {} is in use, remapped to port {}", BasePort, EffectivePort);
+ ZEN_WARN("{}: Desired port {} is in use, remapped to port {}", m_Label, BasePort, EffectivePort);
}
if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>)
@@ -1559,54 +1682,64 @@ struct HttpAcceptor
// IPv6 loopback will only respond on the IPv6 loopback address. Not everyone does
// IPv6 though so we also bind to IPv4 loopback (localhost/127.0.0.1)
- m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), BindErrorCode);
+ asio::error_code AltEc;
+ m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), AltEc);
- if (BindErrorCode)
+ if (AltEc)
{
- ZEN_WARN("Failed to register secondary IPv4 local-only handler 'http://{}:{}/'", "localhost", EffectivePort);
+ ZEN_WARN("{}: Failed to register secondary IPv4 local-only handler on port {}", m_Label, EffectivePort);
}
else
{
m_UseAlternateProtocolAcceptor = true;
- ZEN_INFO("Registered local-only handler 'http://{}:{}/' - this is not accessible from remote hosts",
- "localhost",
- EffectivePort);
}
}
}
- m_IsValid = true;
+#if ZEN_PLATFORM_WINDOWS
+ // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
+ // This must be used by both the client and server side, and is only effective in the absence of
+ // Windows Filtering Platform (WFP) callouts which can be installed by security software.
+ // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
+ SOCKET NativeSocket = m_Acceptor.native_handle();
+ int LoopbackOptionValue = 1;
+ DWORD OptionNumberOfBytesReturned = 0;
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
- if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>)
- {
- return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "[::1]" : "*", EffectivePort);
- }
- else
+ if (m_UseAlternateProtocolAcceptor)
{
- return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "127.0.0.1" : "*", EffectivePort);
+ NativeSocket = m_AlternateProtocolAcceptor.native_handle();
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
}
- }
-
- void Start()
- {
- ZEN_MEMSCOPE(GetHttpasioTag());
+#endif
- ZEN_ASSERT(!m_IsStopped);
- InitAcceptInternal(m_Acceptor);
+ m_Acceptor.listen();
if (m_UseAlternateProtocolAcceptor)
{
- InitAcceptInternal(m_AlternateProtocolAcceptor);
+ m_AlternateProtocolAcceptor.listen();
}
- }
- void StopAccepting() { m_IsStopped = true; }
-
- int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); }
-
- bool IsValid() const { return m_IsValid; }
+ m_IsValid = true;
+ ZEN_INFO("{}: Listening on port {}", m_Label, m_Acceptor.local_endpoint().port());
+ }
-private:
- void InitAcceptInternal(asio::ip::tcp::acceptor& Acceptor)
+ void InitAcceptLoop(asio::ip::tcp::acceptor& Acceptor)
{
auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService);
asio::ip::tcp::socket& SocketRef = *SocketPtr.get();
@@ -1614,29 +1747,19 @@ private:
Acceptor.async_accept(SocketRef, [this, &Acceptor, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
if (Ec)
{
- ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'",
- Acceptor.local_endpoint().address().to_string(),
- Acceptor.local_endpoint().port(),
- Ec.message());
+ if (!m_IsStopped.load())
+ {
+ ZEN_WARN("{}: async_accept failed: '{}'", m_Label, Ec.message());
+ }
}
else
{
- // New connection established, pass socket ownership into connection object
- // and initiate request handling loop. The connection lifetime is
- // managed by the async read/write loop by passing the shared
- // reference to the callbacks.
-
- Socket->set_option(asio::ip::tcp::no_delay(true));
- Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
- auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
- Conn->HandleNewRequest();
+ OnAccept(std::move(Socket));
}
if (!m_IsStopped.load())
{
- InitAcceptInternal(Acceptor);
+ InitAcceptLoop(Acceptor);
}
else
{
@@ -1644,33 +1767,218 @@ private:
Acceptor.close(CloseEc);
if (CloseEc)
{
- ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message());
+ ZEN_WARN("{}: acceptor close error: '{}'", m_Label, CloseEc.message());
}
}
});
}
- HttpAsioServerImpl& m_Server;
- asio::io_service& m_IoService;
asio::ip::tcp::acceptor m_Acceptor;
asio::ip::tcp::acceptor m_AlternateProtocolAcceptor;
bool m_UseAlternateProtocolAcceptor{false};
bool m_IsValid{false};
std::atomic<bool> m_IsStopped{false};
+ std::string_view m_Label;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpAcceptor final : TcpAcceptorBase
+{
+ HttpAcceptor(HttpAsioServerImpl& Server, asio::io_context& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ : TcpAcceptorBase(Server, IoService, BasePort, ForceLoopback, AllowPortProbing, "HTTP")
+ {
+ }
+
+ int GetAcceptPort() const { return GetPort(); }
+
+protected:
+ void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override
+ {
+ Socket->set_option(asio::ip::tcp::no_delay(true));
+ Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
+ Conn->HandleNewRequest();
+ }
+};
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+
+//////////////////////////////////////////////////////////////////////////
+
+struct UnixAcceptor
+{
+ UnixAcceptor(HttpAsioServerImpl& Server, asio::io_context& IoService, const std::string& SocketPath)
+ : m_Server(Server)
+ , m_IoService(IoService)
+ , m_Acceptor(m_IoService)
+ , m_SocketPath(SocketPath)
+ {
+ // Remove any stale socket file from a previous run
+ std::filesystem::remove(m_SocketPath);
+
+ asio::local::stream_protocol::endpoint Endpoint(m_SocketPath);
+
+ asio::error_code Ec;
+ m_Acceptor.open(Endpoint.protocol(), Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to open unix domain socket: {}", Ec.message());
+ return;
+ }
+
+ m_Acceptor.bind(Endpoint, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to bind unix domain socket at '{}': {}", m_SocketPath, Ec.message());
+ return;
+ }
+
+ m_Acceptor.listen(asio::socket_base::max_listen_connections, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to listen on unix domain socket at '{}': {}", m_SocketPath, Ec.message());
+ return;
+ }
+
+ m_IsValid = true;
+ ZEN_INFO("Started unix domain socket listener at '{}'", m_SocketPath);
+ }
+
+ ~UnixAcceptor()
+ {
+ asio::error_code Ec;
+ m_Acceptor.close(Ec);
+ std::filesystem::remove(m_SocketPath);
+ }
+
+ void Start()
+ {
+ ZEN_ASSERT(!m_IsStopped);
+ InitAccept();
+ }
+
+ void StopAccepting() { m_IsStopped = true; }
+
+ bool IsValid() const { return m_IsValid; }
+
+private:
+ void InitAccept()
+ {
+ auto SocketPtr = std::make_unique<asio::local::stream_protocol::socket>(m_IoService);
+ asio::local::stream_protocol::socket& SocketRef = *SocketPtr.get();
+
+ m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ if (!m_IsStopped.load())
+ {
+ ZEN_WARN("unix domain socket async_accept failed: '{}'", Ec.message());
+ }
+ }
+ else
+ {
+ auto Conn = std::make_shared<UnixServerConnection>(m_Server, std::move(Socket));
+ Conn->HandleNewRequest();
+ }
+
+ if (!m_IsStopped.load())
+ {
+ InitAccept();
+ }
+ else
+ {
+ std::error_code CloseEc;
+ m_Acceptor.close(CloseEc);
+ }
+ });
+ }
+
+ HttpAsioServerImpl& m_Server;
+ asio::io_context& m_IoService;
+ asio::local::stream_protocol::acceptor m_Acceptor;
+ std::string m_SocketPath;
+ bool m_IsValid{false};
+ std::atomic<bool> m_IsStopped{false};
+};
+
+#endif // ASIO_HAS_LOCAL_SOCKETS
+
+#if ZEN_USE_OPENSSL
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpsAcceptor final : TcpAcceptorBase
+{
+ HttpsAcceptor(HttpAsioServerImpl& Server,
+ asio::io_context& IoService,
+ asio::ssl::context& SslContext,
+ uint16_t Port,
+ bool ForceLoopback,
+ bool AllowPortProbing)
+ : TcpAcceptorBase(Server, IoService, Port, ForceLoopback, AllowPortProbing, "HTTPS")
+ , m_SslContext(SslContext)
+ {
+ }
+
+protected:
+ void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override
+ {
+ Socket->set_option(asio::ip::tcp::no_delay(true));
+ Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ // Wrap accepted TCP socket in an SSL stream and perform the handshake
+ auto SslSocketPtr = std::make_unique<SslSocket>(std::move(*Socket), m_SslContext);
+
+ SslSocket& SslRef = *SslSocketPtr;
+ SslRef.async_handshake(asio::ssl::stream_base::server,
+ [this, SslSocket = std::move(SslSocketPtr)](const asio::error_code& HandshakeEc) mutable {
+ if (HandshakeEc)
+ {
+ ZEN_WARN("SSL handshake failed: '{}'", HandshakeEc.message());
+ std::error_code Ec;
+ SslSocket->lowest_layer().close(Ec);
+ return;
+ }
+
+ auto Conn = std::make_shared<HttpsSslServerConnection>(m_Server, std::move(SslSocket));
+ Conn->HandleNewRequest();
+ });
+ }
+
+private:
+ asio::ssl::context& m_SslContext;
};
+#endif // ZEN_USE_OPENSSL
+
+int
+HttpAsioServerImpl::GetEffectiveHttpsPort() const
+{
+#if ZEN_USE_OPENSSL
+ return m_HttpsAcceptor ? m_HttpsAcceptor->GetPort() : 0;
+#else
+ return 0;
+#endif
+}
+
//////////////////////////////////////////////////////////////////////////
HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request,
HttpService& Service,
IoBuffer PayloadBuffer,
uint32_t RequestNumber,
- bool IsLocalMachineRequest)
+ bool IsLocalMachineRequest,
+ std::string RemoteAddress)
: HttpServerRequest(Service)
, m_Request(Request)
, m_RequestNumber(RequestNumber)
, m_PayloadBuffer(std::move(PayloadBuffer))
, m_IsLocalMachineRequest(IsLocalMachineRequest)
+, m_RemoteAddress(std::move(RemoteAddress))
{
const int PrefixLength = Service.UriPrefixLength();
@@ -1749,6 +2057,12 @@ HttpAsioServerRequest::IsLocalMachineRequest() const
}
std::string_view
+HttpAsioServerRequest::GetRemoteAddress() const
+{
+ return m_RemoteAddress;
+}
+
+std::string_view
HttpAsioServerRequest::GetAuthorizationHeader() const
{
return m_Request.AuthorizationHeader();
@@ -1768,6 +2082,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
std::array<IoBuffer, 0> Empty;
m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
@@ -1781,6 +2096,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(ContentType, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
}
@@ -1791,6 +2107,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(ContentType, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size());
std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
@@ -1840,15 +2157,63 @@ HttpAsioServerImpl::Start(uint16_t Port, const AsioConfig& Config)
ZEN_INFO("starting asio http with {} service threads", Config.ThreadCount);
- m_Acceptor.reset(
- new asio_http::HttpAcceptor(*this, m_IoService, Port, Config.ForceLoopback, /*AllowPortProbing */ !Config.IsDedicatedServer));
+ if (!Config.NoNetwork)
+ {
+ m_Acceptor.reset(
+ new asio_http::HttpAcceptor(*this, m_IoService, Port, Config.ForceLoopback, /*AllowPortProbing */ !Config.IsDedicatedServer));
+
+ if (!m_Acceptor->IsValid())
+ {
+ return 0;
+ }
- if (!m_Acceptor->IsValid())
+ m_Acceptor->Start();
+ }
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (!Config.UnixSocketPath.empty())
{
- return 0;
+ m_UnixAcceptor.reset(new asio_http::UnixAcceptor(*this, m_IoService, Config.UnixSocketPath));
+
+ if (m_UnixAcceptor->IsValid())
+ {
+ m_UnixAcceptor->Start();
+ }
+ else
+ {
+ m_UnixAcceptor.reset();
+ }
}
+#endif
+
+#if ZEN_USE_OPENSSL
+ if (!Config.NoNetwork && !Config.CertFile.empty() && !Config.KeyFile.empty())
+ {
+ m_SslContext = std::make_unique<asio::ssl::context>(asio::ssl::context::tlsv12_server);
+ m_SslContext->set_options(asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 |
+ asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1);
+ m_SslContext->use_certificate_chain_file(Config.CertFile);
+ m_SslContext->use_private_key_file(Config.KeyFile, asio::ssl::context::pem);
- m_Acceptor->Start();
+ ZEN_INFO("SSL context initialized (cert: '{}', key: '{}')", Config.CertFile, Config.KeyFile);
+
+ m_HttpsAcceptor.reset(new asio_http::HttpsAcceptor(*this,
+ m_IoService,
+ *m_SslContext,
+ gsl::narrow<uint16_t>(Config.HttpsPort),
+ Config.ForceLoopback,
+ /*AllowPortProbing*/ !Config.IsDedicatedServer));
+
+ if (m_HttpsAcceptor->IsValid())
+ {
+ m_HttpsAcceptor->Start();
+ }
+ else
+ {
+ m_HttpsAcceptor.reset();
+ }
+ }
+#endif
// This should consist of a set of minimum threads and grow on demand to
// meet concurrency needs? Right now we end up allocating a large number
@@ -1881,12 +2246,18 @@ HttpAsioServerImpl::Start(uint16_t Port, const AsioConfig& Config)
});
}
- ZEN_INFO("asio http started in {} mode, using {} threads on port {}",
- Config.IsDedicatedServer ? "DEDICATED" : "NORMAL",
- Config.ThreadCount,
- m_Acceptor->GetAcceptPort());
+ if (m_Acceptor)
+ {
+ ZEN_INFO("asio http started in {} mode, using {} threads on port {}",
+ Config.IsDedicatedServer ? "DEDICATED" : "NORMAL",
+ Config.ThreadCount,
+ m_Acceptor->GetAcceptPort());
- return m_Acceptor->GetAcceptPort();
+ return m_Acceptor->GetAcceptPort();
+ }
+
+ ZEN_INFO("asio http started in no-network mode, using {} threads (unix socket only)", Config.ThreadCount);
+ return Port;
}
void
@@ -1898,6 +2269,18 @@ HttpAsioServerImpl::Stop()
{
m_Acceptor->StopAccepting();
}
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixAcceptor)
+ {
+ m_UnixAcceptor->StopAccepting();
+ }
+#endif
+#if ZEN_USE_OPENSSL
+ if (m_HttpsAcceptor)
+ {
+ m_HttpsAcceptor->StopAccepting();
+ }
+#endif
m_IoService.stop();
for (auto& Thread : m_ThreadPool)
{
@@ -1907,7 +2290,23 @@ HttpAsioServerImpl::Stop()
}
}
m_ThreadPool.clear();
+
+ // Drain remaining handlers (e.g. cancellation callbacks from active WebSocket
+ // connections) so that their captured Ref<> pointers are released while the
+ // io_context and its epoll reactor are still alive. Without this, sockets
+ // held by external code (e.g. IWebSocketHandler connection lists) can outlive
+ // the reactor and crash during deregistration.
+ m_IoService.restart();
+ m_IoService.poll();
+
m_Acceptor.reset();
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ m_UnixAcceptor.reset();
+#endif
+#if ZEN_USE_OPENSSL
+ m_HttpsAcceptor.reset();
+ m_SslContext.reset();
+#endif
}
void
@@ -1975,6 +2374,12 @@ HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request)
return RequestFilter->FilterRequest(Request);
}
+bool
+HttpAsioServerImpl::IsLoopbackOnly() const
+{
+ return m_Acceptor && m_Acceptor->IsLoopbackOnly();
+}
+
} // namespace zen::asio_http
//////////////////////////////////////////////////////////////////////////
@@ -1987,12 +2392,15 @@ public:
HttpAsioServer(const AsioConfig& Config);
~HttpAsioServer();
- virtual void OnRegisterService(HttpService& Service) override;
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
- virtual void OnRun(bool IsInteractiveSession) override;
- virtual void OnRequestExit() override;
- virtual void OnClose() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual void OnRun(bool IsInteractiveSession) override;
+ virtual void OnRequestExit() override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
+ virtual uint64_t GetTotalBytesReceived() const override;
+ virtual uint64_t GetTotalBytesSent() const override;
private:
Event m_ShutdownEvent;
@@ -2006,6 +2414,7 @@ HttpAsioServer::HttpAsioServer(const AsioConfig& Config)
: m_InitialConfig(Config)
, m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>())
{
+ m_Impl->m_HttpServer = this;
ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), sizeof(HttpRequestParser));
}
@@ -2064,9 +2473,51 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Config);
+#if ZEN_USE_OPENSSL
+ if (int EffectiveHttpsPort = m_Impl->GetEffectiveHttpsPort(); EffectiveHttpsPort > 0)
+ {
+ SetEffectiveHttpsPort(EffectiveHttpsPort);
+ }
+#endif
+
return m_BasePort;
}
+std::string
+HttpAsioServer::OnGetExternalHost() const
+{
+ if (m_Impl->IsLoopbackOnly())
+ {
+ return "127.0.0.1";
+ }
+
+ // Use the UDP connect trick: connecting a UDP socket to an external address
+ // causes the OS to select the appropriate local interface without sending any data.
+ try
+ {
+ asio::io_context IoService;
+ asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4());
+ Sock.connect(asio::ip::udp::endpoint(asio::ip::make_address("8.8.8.8"), 80));
+ return Sock.local_endpoint().address().to_string();
+ }
+ catch (const std::exception&)
+ {
+ return GetMachineName();
+ }
+}
+
+uint64_t
+HttpAsioServer::GetTotalBytesReceived() const
+{
+ return m_Impl->m_TotalBytesReceived.load(std::memory_order_relaxed);
+}
+
+uint64_t
+HttpAsioServer::GetTotalBytesSent() const
+{
+ return m_Impl->m_TotalBytesSent.load(std::memory_order_relaxed);
+}
+
void
HttpAsioServer::OnRun(bool IsInteractive)
{
diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h
index 3ec1141a7..21d10170e 100644
--- a/src/zenhttp/servers/httpasio.h
+++ b/src/zenhttp/servers/httpasio.h
@@ -11,6 +11,13 @@ struct AsioConfig
unsigned int ThreadCount = 0;
bool ForceLoopback = false;
bool IsDedicatedServer = false;
+ bool NoNetwork = false;
+ std::string UnixSocketPath;
+#if ZEN_USE_OPENSSL
+ int HttpsPort = 0; // 0 = auto-assign; set CertFile/KeyFile to enable HTTPS
+ std::string CertFile; // PEM certificate chain file (empty = HTTPS disabled)
+ std::string KeyFile; // PEM private key file
+#endif
};
Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config);
diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp
index 310ac9dc0..584e06cbf 100644
--- a/src/zenhttp/servers/httpmulti.cpp
+++ b/src/zenhttp/servers/httpmulti.cpp
@@ -117,6 +117,16 @@ HttpMultiServer::OnClose()
}
}
+std::string
+HttpMultiServer::OnGetExternalHost() const
+{
+ if (!m_Servers.empty())
+ {
+ return std::string(m_Servers.front()->GetExternalHost());
+ }
+ return HttpServer::OnGetExternalHost();
+}
+
void
HttpMultiServer::AddServer(Ref<HttpServer> Server)
{
diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h
index 1897587a9..97699828a 100644
--- a/src/zenhttp/servers/httpmulti.h
+++ b/src/zenhttp/servers/httpmulti.h
@@ -15,12 +15,13 @@ public:
HttpMultiServer();
~HttpMultiServer();
- virtual void OnRegisterService(HttpService& Service) override;
- virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool IsInteractiveSession) override;
- virtual void OnRequestExit() override;
- virtual void OnClose() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnRun(bool IsInteractiveSession) override;
+ virtual void OnRequestExit() override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
void AddServer(Ref<HttpServer> Server);
diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp
index f0485aa25..918b55dc6 100644
--- a/src/zenhttp/servers/httpparser.cpp
+++ b/src/zenhttp/servers/httpparser.cpp
@@ -12,14 +12,17 @@ namespace zen {
using namespace std::literals;
-static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
-static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
-static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
-static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
-static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
-static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
-static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
-static constinit uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv);
+static constexpr uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
+static constexpr uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
+static constexpr uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
+static constexpr uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
+static constexpr uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
+static constexpr uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
+static constexpr uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+static constexpr uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv);
+static constexpr uint32_t HashUpgrade = HashStringAsLowerDjb2("Upgrade"sv);
+static constexpr uint32_t HashSecWebSocketKey = HashStringAsLowerDjb2("Sec-WebSocket-Key"sv);
+static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-WebSocket-Version"sv);
//////////////////////////////////////////////////////////////////////////
//
@@ -143,45 +146,62 @@ HttpRequestParser::ParseCurrentHeader()
const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName);
const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1);
- if (HeaderHash == HashContentLength)
+ switch (HeaderHash)
{
- m_ContentLengthHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashAccept)
- {
- m_AcceptHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashContentType)
- {
- m_ContentTypeHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashAuthorization)
- {
- m_AuthorizationHeaderIndex = CurrentHeaderIndex;
- }
- else if (HeaderHash == HashSession)
- {
- m_SessionId = Oid::TryFromHexString(HeaderValue);
- }
- else if (HeaderHash == HashRequest)
- {
- std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
- }
- else if (HeaderHash == HashExpect)
- {
- if (HeaderValue == "100-continue"sv)
- {
- // We don't currently do anything with this
- m_Expect100Continue = true;
- }
- else
- {
- ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
- }
- }
- else if (HeaderHash == HashRange)
- {
- m_RangeHeaderIndex = CurrentHeaderIndex;
+ case HashContentLength:
+ m_ContentLengthHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashAccept:
+ m_AcceptHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashContentType:
+ m_ContentTypeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashAuthorization:
+ m_AuthorizationHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSession:
+ m_SessionId = Oid::TryFromHexString(HeaderValue);
+ break;
+
+ case HashRequest:
+ std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
+ break;
+
+ case HashExpect:
+ if (HeaderValue == "100-continue"sv)
+ {
+ // We don't currently do anything with this
+ m_Expect100Continue = true;
+ }
+ else
+ {
+ ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
+ }
+ break;
+
+ case HashRange:
+ m_RangeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashUpgrade:
+ m_UpgradeHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSecWebSocketKey:
+ m_SecWebSocketKeyHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ case HashSecWebSocketVersion:
+ m_SecWebSocketVersionHeaderIndex = CurrentHeaderIndex;
+ break;
+
+ default:
+ break;
}
}
@@ -225,13 +245,6 @@ NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl)
NormalizedUrl.reserve(UrlLength);
NormalizedUrl.append(Url, UrlIndex);
}
-
- // NOTE: this check is redundant given the enclosing if,
- // need to verify the intent of this code
- if (!LastCharWasSeparator)
- {
- NormalizedUrl.push_back('/');
- }
}
else if (!NormalizedUrl.empty())
{
@@ -361,14 +374,18 @@ HttpRequestParser::ResetState()
m_HeaderEntries.clear();
- m_ContentLengthHeaderIndex = -1;
- m_AcceptHeaderIndex = -1;
- m_ContentTypeHeaderIndex = -1;
- m_RangeHeaderIndex = -1;
- m_AuthorizationHeaderIndex = -1;
- m_Expect100Continue = false;
- m_BodyBuffer = {};
- m_BodyPosition = 0;
+ m_ContentLengthHeaderIndex = -1;
+ m_AcceptHeaderIndex = -1;
+ m_ContentTypeHeaderIndex = -1;
+ m_RangeHeaderIndex = -1;
+ m_AuthorizationHeaderIndex = -1;
+ m_UpgradeHeaderIndex = -1;
+ m_SecWebSocketKeyHeaderIndex = -1;
+ m_SecWebSocketVersionHeaderIndex = -1;
+ m_RequestVerb = HttpVerb::kGet;
+ m_Expect100Continue = false;
+ m_BodyBuffer = {};
+ m_BodyPosition = 0;
m_HeaderData.clear();
m_NormalizedUrl.clear();
@@ -425,4 +442,21 @@ HttpRequestParser::OnMessageComplete()
}
}
+bool
+HttpRequestParser::IsWebSocketUpgrade() const
+{
+ std::string_view Upgrade = GetHeaderValue(m_UpgradeHeaderIndex);
+ if (Upgrade.empty())
+ {
+ return false;
+ }
+
+ // Case-insensitive check for "websocket"
+ if (Upgrade.size() != 9)
+ {
+ return false;
+ }
+ return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0;
+}
+
} // namespace zen
diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h
index ff56ca970..23ad9d8fb 100644
--- a/src/zenhttp/servers/httpparser.h
+++ b/src/zenhttp/servers/httpparser.h
@@ -48,6 +48,10 @@ struct HttpRequestParser
std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); }
+ std::string_view UpgradeHeader() const { return GetHeaderValue(m_UpgradeHeaderIndex); }
+ std::string_view SecWebSocketKey() const { return GetHeaderValue(m_SecWebSocketKeyHeaderIndex); }
+ bool IsWebSocketUpgrade() const;
+
private:
struct HeaderRange
{
@@ -86,7 +90,10 @@ private:
int8_t m_ContentTypeHeaderIndex;
int8_t m_RangeHeaderIndex;
int8_t m_AuthorizationHeaderIndex;
- HttpVerb m_RequestVerb;
+ int8_t m_UpgradeHeaderIndex;
+ int8_t m_SecWebSocketKeyHeaderIndex;
+ int8_t m_SecWebSocketVersionHeaderIndex;
+ HttpVerb m_RequestVerb = HttpVerb::kGet;
std::atomic_bool m_KeepAlive{false};
bool m_Expect100Continue = false;
int m_RequestId = -1;
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp
index 8564826d6..a1bb719c8 100644
--- a/src/zenhttp/servers/httpplugin.cpp
+++ b/src/zenhttp/servers/httpplugin.cpp
@@ -123,7 +123,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
bool m_IsRequestLoggingEnabled = false;
LoggerRef m_RequestLog;
std::atomic_uint32_t m_ConnectionIdCounter{0};
- int m_BasePort;
+ int m_BasePort = 0;
HttpServerTracer m_RequestTracer;
@@ -147,7 +147,7 @@ public:
HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete;
// As this is plugin transport connection used for specialized connections we assume it is not a machine local connection
- virtual bool IsLocalMachineRequest() const /* override*/ { return false; }
+ bool IsLocalMachineRequest() const override { return false; }
virtual std::string_view GetAuthorizationHeader() const override;
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
@@ -294,7 +294,7 @@ HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPlug
ConnectionName = "anonymous";
}
- ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('')", m_ConnectionId, ConnectionName);
+ ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('{}')", m_ConnectionId, ConnectionName);
}
uint32_t
@@ -378,12 +378,14 @@ HttpPluginConnectionHandler::HandleRequest()
{
ZEN_TRACE_CPU("http_plugin::HandleRequest");
+ m_Server->MarkRequest();
+
HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body());
const HttpVerb RequestVerb = Request.RequestVerb();
const std::string_view Uri = Request.RelativeUri();
- if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server->m_RequestLog.ShouldLog(logging::Trace))
{
ZEN_LOG_TRACE(m_Server->m_RequestLog,
"connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})",
@@ -480,7 +482,7 @@ HttpPluginConnectionHandler::HandleRequest()
const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers();
- if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server->m_RequestLog.ShouldLog(logging::Trace))
{
m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber),
ResponseBuffers);
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 4406d0619..eaf080960 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -12,6 +12,7 @@
#include <zencore/memory/llm.h>
#include <zencore/scopeguard.h>
#include <zencore/string.h>
+#include <zencore/system.h>
#include <zencore/timer.h>
#include <zencore/trace.h>
#include <zenhttp/packageformat.h>
@@ -25,7 +26,9 @@
# include <zencore/workthreadpool.h>
# include "iothreadpool.h"
+# include <atomic>
# include <http.h>
+# include <asio.hpp> // for resolving addresses for GetExternalHost
namespace zen {
@@ -85,6 +88,8 @@ class HttpSysServerRequest;
class HttpSysServer : public HttpServer
{
friend class HttpSysTransaction;
+ friend class HttpMessageResponseRequest;
+ friend struct InitialRequestHandler;
public:
explicit HttpSysServer(const HttpSysConfig& Config);
@@ -92,12 +97,15 @@ public:
// HttpServer interface implementation
- virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
- virtual void OnRun(bool TestMode) override;
- virtual void OnRequestExit() override;
- virtual void OnRegisterService(HttpService& Service) override;
- virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
- virtual void OnClose() override;
+ virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnRun(bool TestMode) override;
+ virtual void OnRequestExit() override;
+ virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
+ virtual void OnClose() override;
+ virtual std::string OnGetExternalHost() const override;
+ virtual uint64_t GetTotalBytesReceived() const override;
+ virtual uint64_t GetTotalBytesSent() const override;
WorkerThreadPool& WorkPool();
@@ -108,6 +116,12 @@ public:
private:
int InitializeServer(int BasePort);
+ bool CreateSessionAndUrlGroup();
+ bool RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris);
+ int RegisterHttpUrls(int BasePort);
+ bool RegisterHttpsUrls();
+ bool CreateRequestQueue(int EffectivePort);
+ bool SetupIoCompletionPort();
void Cleanup();
void StartServer();
@@ -117,6 +131,9 @@ private:
void RegisterService(const char* Endpoint, HttpService& Service);
void UnregisterService(const char* Endpoint, HttpService& Service);
+ bool BindSslCertificate(int Port);
+ void UnbindSslCertificate();
+
private:
LoggerRef m_Log;
LoggerRef m_RequestLog;
@@ -130,10 +147,13 @@ private:
std::unique_ptr<WinIoThreadPool> m_IoThreadPool;
bool m_IoThreadPoolIsWinTp = true;
- RwLock m_AsyncWorkPoolInitLock;
- WorkerThreadPool* m_AsyncWorkPool = nullptr;
+ RwLock m_AsyncWorkPoolInitLock;
+ std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr;
- std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ std::vector<std::wstring> m_HttpsBaseUris; // eg: https://*:nnnn/
+ bool m_DidAutoBindCert = false;
+ int m_HttpsPort = 0;
HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0;
HANDLE m_RequestQueueHandle = 0;
@@ -146,6 +166,9 @@ private:
RwLock m_RequestFilterLock;
std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+
+ std::atomic<uint64_t> m_TotalBytesReceived{0};
+ std::atomic<uint64_t> m_TotalBytesSent{0};
};
} // namespace zen
@@ -153,6 +176,10 @@ private:
#if ZEN_WITH_HTTPSYS
+# include "httpsys_iocontext.h"
+# include "wshttpsys.h"
+# include "wsframecodec.h"
+
# include <conio.h>
# include <mstcpip.h>
# pragma comment(lib, "httpapi.lib")
@@ -322,8 +349,9 @@ public:
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
- virtual bool IsLocalMachineRequest() const;
+ virtual bool IsLocalMachineRequest() const override;
virtual std::string_view GetAuthorizationHeader() const override;
+ virtual std::string_view GetRemoteAddress() const override;
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
@@ -339,11 +367,12 @@ public:
HttpSysServerRequest(const HttpSysServerRequest&) = delete;
HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete;
- HttpSysTransaction& m_HttpTx;
- HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
- IoBuffer m_PayloadBuffer;
- ExtendableStringBuilder<128> m_UriUtf8;
- ExtendableStringBuilder<128> m_QueryStringUtf8;
+ HttpSysTransaction& m_HttpTx;
+ HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
+ IoBuffer m_PayloadBuffer;
+ ExtendableStringBuilder<128> m_UriUtf8;
+ ExtendableStringBuilder<128> m_QueryStringUtf8;
+ mutable ExtendableStringBuilder<64> m_RemoteAddress;
};
/** HTTP transaction
@@ -378,7 +407,7 @@ public:
void StartIo();
void CancelIo();
HANDLE RequestQueueHandle();
- inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
+ inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; }
inline HttpSysServer& Server() { return m_HttpServer; }
inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
@@ -395,8 +424,8 @@ public:
};
private:
- OVERLAPPED m_HttpOverlapped{};
- HttpSysServer& m_HttpServer;
+ HttpSysIoContext m_IoContext{};
+ HttpSysServer& m_HttpServer;
// Tracks which handler is due to handle the next I/O completion event
HttpSysRequestHandler* m_CompletionHandler = nullptr;
@@ -436,6 +465,8 @@ public:
inline uint16_t GetResponseCode() const { return m_ResponseCode; }
inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
+ void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; }
+
private:
eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks;
uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes
@@ -445,6 +476,7 @@ private:
bool m_IsInitialResponse = true;
HttpContentType m_ContentType = HttpContentType::kBinary;
eastl::fixed_vector<IoBuffer, 16> m_DataBuffers;
+ std::string m_LocationHeader;
void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs);
};
@@ -585,7 +617,7 @@ HttpMessageResponseRequest::SuppressResponseBody()
HttpSysRequestHandler*
HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
- ZEN_UNUSED(NumberOfBytesTransferred);
+ Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
if (IoResult != NO_ERROR)
{
@@ -699,6 +731,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
ContentTypeHeader->pRawValue = ContentTypeString.data();
ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
+ // Location header (for redirects)
+
+ if (!m_LocationHeader.empty())
+ {
+ PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation];
+ LocationHeader->pRawValue = m_LocationHeader.data();
+ LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size();
+ }
+
std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
HttpResponse.StatusCode = m_ResponseCode;
@@ -900,7 +941,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr
ZEN_UNUSED(IoResult, NumberOfBytesTransferred);
- ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred);
+ ZEN_WARN("Unexpected I/O completion during async work! IoResult: {} ({:#x}), NumberOfBytesTransferred: {}",
+ GetSystemErrorAsString(IoResult),
+ IoResult,
+ NumberOfBytesTransferred);
return this;
}
@@ -1035,8 +1079,10 @@ HttpSysServer::~HttpSysServer()
ZEN_ERROR("~HttpSysServer() called without calling Close() first");
}
- delete m_AsyncWorkPool;
+ auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed);
m_AsyncWorkPool = nullptr;
+
+ delete WorkPool;
}
void
@@ -1051,36 +1097,63 @@ HttpSysServer::OnClose()
}
}
-int
-HttpSysServer::InitializeServer(int BasePort)
+bool
+HttpSysServer::CreateSessionAndUrlGroup()
{
- ZEN_MEMSCOPE(GetHttpsysTag());
-
- using namespace std::literals;
-
- WideStringBuilder<64> WildcardUrlPath;
- WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
-
- m_IsOk = false;
-
ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create server session: {} ({:#x})", GetSystemErrorAsString(Result), Result);
- return 0;
+ return false;
}
Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to create URL group: {} ({:#x})", GetSystemErrorAsString(Result), Result);
- return 0;
+ return false;
}
+ return true;
+}
+
+bool
+HttpSysServer::RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris)
+{
+ using namespace std::literals;
+
+ const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
+
+ for (const std::u8string_view Host : Hosts)
+ {
+ WideStringBuilder<64> LocalUrl;
+ LocalUrl << Scheme << u8"://"sv << Host << u8":"sv << int64_t(Port) << u8"/"sv;
+
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrl.c_str(), HTTP_URL_CONTEXT(0), 0);
+
+ if (Result == NO_ERROR)
+ {
+ ZEN_WARN("Registered local-only handler '{}' - this is not accessible from remote hosts", WideToUtf8(LocalUrl));
+ OutUris.push_back(LocalUrl.c_str());
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ return !OutUris.empty();
+}
+
+int
+HttpSysServer::RegisterHttpUrls(int BasePort)
+{
+ using namespace std::literals;
+
m_BaseUris.clear();
const bool AllowPortProbing = !m_InitialConfig.IsDedicatedServer;
@@ -1088,6 +1161,11 @@ HttpSysServer::InitializeServer(int BasePort)
int EffectivePort = BasePort;
+ WideStringBuilder<64> WildcardUrlPath;
+ WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
+
+ ULONG Result;
+
if (m_InitialConfig.ForceLoopback)
{
// Force trigger of opening using local port
@@ -1100,7 +1178,9 @@ HttpSysServer::InitializeServer(int BasePort)
if ((Result == ERROR_SHARING_VIOLATION))
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
@@ -1122,7 +1202,9 @@ HttpSysServer::InitializeServer(int BasePort)
{
for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++)
{
- ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result);
+ ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying",
+ EffectivePort,
+ GetSystemErrorAsString(Result));
Sleep(500);
Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
}
@@ -1139,11 +1221,11 @@ HttpSysServer::InitializeServer(int BasePort)
{
if (AllowLocalOnly)
{
- // If we can't register the wildcard path, we fall back to local paths
- // This local paths allow requests originating locally to function, but will not allow
- // remote origin requests to function. This can be remedied by using netsh
+ // If we can't register the wildcard path, we fall back to local paths.
+ // Local paths allow requests originating locally to function, but will not allow
+ // remote origin requests to function. This can be remedied by using netsh
// during an install process to grant permissions to route public access to the appropriate
- // port for the current user. eg:
+ // port for the current user. eg:
// netsh http add urlacl url=http://*:8558/ user=<some_user>
if (!m_InitialConfig.ForceLoopback)
@@ -1157,17 +1239,18 @@ HttpSysServer::InitializeServer(int BasePort)
const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
- ULONG InternalResult = ERROR_SHARING_VIOLATION;
- for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
+ bool ShouldRetryNextPort = true;
+ for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset)
{
- EffectivePort = BasePort + (PortOffset * 100);
+ EffectivePort = BasePort + (PortOffset * 100);
+ ShouldRetryNextPort = false;
for (const std::u8string_view Host : Hosts)
{
WideStringBuilder<64> LocalUrlPath;
LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv;
- InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+ ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
if (InternalResult == NO_ERROR)
{
@@ -1175,11 +1258,25 @@ HttpSysServer::InitializeServer(int BasePort)
m_BaseUris.push_back(LocalUrlPath.c_str());
}
+ else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED)
+ {
+ // Port may be owned by another process's wildcard registration (access denied)
+ // or actively in use (sharing violation) — retry on a different port
+ ShouldRetryNextPort = true;
+ }
else
{
- break;
+ ZEN_WARN("Failed to register local handler '{}': {} ({:#x})",
+ WideToUtf8(LocalUrlPath),
+ GetSystemErrorAsString(InternalResult),
+ InternalResult);
}
}
+
+ if (!m_BaseUris.empty())
+ {
+ break;
+ }
}
}
else
@@ -1193,29 +1290,123 @@ HttpSysServer::InitializeServer(int BasePort)
}
}
- if (m_BaseUris.empty())
+ if (m_BaseUris.empty() && m_InitialConfig.HttpsPort == 0)
{
- ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+ ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})",
+ WideToUtf8(WildcardUrlPath),
+ GetSystemErrorAsString(Result),
+ Result);
return 0;
}
+ return EffectivePort;
+}
+
+bool
+HttpSysServer::RegisterHttpsUrls()
+{
+ using namespace std::literals;
+
+ const bool AllowLocalOnly = !m_InitialConfig.IsDedicatedServer;
+ const int HttpsPort = m_InitialConfig.HttpsPort;
+
+ // If HTTPS-only mode, remove HTTP URLs and clear base URIs
+ if (m_InitialConfig.HttpsOnly)
+ {
+ for (const std::wstring& Uri : m_BaseUris)
+ {
+ HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Uri.c_str(), 0);
+ }
+ m_BaseUris.clear();
+ }
+
+ // Auto-bind certificate if thumbprint is provided
+ if (!m_InitialConfig.CertThumbprint.empty())
+ {
+ if (!BindSslCertificate(HttpsPort))
+ {
+ return false;
+ }
+ }
+ else
+ {
+ ZEN_INFO("HTTPS port {} configured without thumbprint - assuming pre-registered SSL certificate", HttpsPort);
+ }
+
+ // Register HTTPS URLs using same pattern as HTTP
+
+ WideStringBuilder<64> HttpsWildcard;
+ HttpsWildcard << u8"https://*:"sv << int64_t(HttpsPort) << u8"/"sv;
+
+ ULONG HttpsResult = NO_ERROR;
+
+ if (m_InitialConfig.ForceLoopback)
+ {
+ HttpsResult = ERROR_ACCESS_DENIED;
+ }
+ else
+ {
+ HttpsResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, HttpsWildcard.c_str(), HTTP_URL_CONTEXT(0), 0);
+ }
+
+ if (HttpsResult == NO_ERROR)
+ {
+ m_HttpsBaseUris.push_back(HttpsWildcard.c_str());
+ }
+ else if (HttpsResult == ERROR_ACCESS_DENIED && AllowLocalOnly)
+ {
+ if (!m_InitialConfig.ForceLoopback)
+ {
+ ZEN_WARN(
+ "Unable to register HTTPS handler using '{}' - falling back to local-only. "
+ "Please ensure the appropriate netsh URL reservation and SSL certificate configuration is made.",
+ WideToUtf8(HttpsWildcard));
+ }
+
+ RegisterLocalUrls(u8"https", HttpsPort, m_HttpsBaseUris);
+ }
+ else if (HttpsResult != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to register HTTPS URL '{}': {} ({:#x})",
+ WideToUtf8(HttpsWildcard),
+ GetSystemErrorAsString(HttpsResult),
+ HttpsResult);
+ return false;
+ }
+
+ if (m_HttpsBaseUris.empty())
+ {
+ ZEN_ERROR("Failed to register any HTTPS URL for port {}", HttpsPort);
+ return false;
+ }
+
+ m_HttpsPort = HttpsPort;
+ return true;
+}
+
+bool
+HttpSysServer::CreateRequestQueue(int EffectivePort)
+{
HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0};
WideStringBuilder<64> QueueName;
QueueName << "zenserver_" << EffectivePort;
- Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
- /* Name */ QueueName.c_str(),
- /* SecurityAttributes */ nullptr,
- /* Flags */ 0,
- &m_RequestQueueHandle);
+ ULONG Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
+ /* Name */ QueueName.c_str(),
+ /* SecurityAttributes */ nullptr,
+ /* Flags */ 0,
+ &m_RequestQueueHandle);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
- return 0;
+ return false;
}
HttpBindingInfo.Flags.Present = 1;
@@ -1225,9 +1416,12 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+ ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})",
+ WideToUtf8(m_BaseUris.front()),
+ GetSystemErrorAsString(Result),
+ Result);
- return 0;
+ return false;
}
// Configure rejection method. Default is to drop the connection, it's better if we
@@ -1257,42 +1451,82 @@ HttpSysServer::InitializeServer(int BasePort)
if (Result != NO_ERROR)
{
- ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result);
+ ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result);
}
}
- // Create I/O completion port
+ return true;
+}
+bool
+HttpSysServer::SetupIoCompletionPort()
+{
std::error_code ErrorCode;
m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode);
if (ErrorCode)
{
- ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message());
+ ZEN_ERROR("Failed to create IOCP: {}", ErrorCode.message());
+ return false;
+ }
+
+ m_IsOk = true;
+
+ if (!m_BaseUris.empty())
+ {
+ ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ }
+ if (!m_HttpsBaseUris.empty())
+ {
+ ZEN_INFO("Started http.sys HTTPS server at '{}'", WideToUtf8(m_HttpsBaseUris.front()));
+ }
+
+ return true;
+}
+
+int
+HttpSysServer::InitializeServer(int BasePort)
+{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
+ m_IsOk = false;
+ if (!CreateSessionAndUrlGroup())
+ {
return 0;
}
- else
+
+ int EffectivePort = RegisterHttpUrls(BasePort);
+
+ if (m_InitialConfig.HttpsPort > 0)
{
- m_IsOk = true;
+ if (!RegisterHttpsUrls())
+ {
+ return 0;
+ }
+ }
- ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ if (m_BaseUris.empty() && m_HttpsBaseUris.empty())
+ {
+ ZEN_ERROR("No HTTP or HTTPS listeners could be registered");
+ return 0;
}
- // This is not available in all Windows SDK versions so for now we can't use recently
- // released functionality. We should investigate how to get more recent SDK releases
- // into the build
+ if (!CreateRequestQueue(EffectivePort))
+ {
+ return 0;
+ }
-# if 0
- if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4))
+ if (!SetupIoCompletionPort())
{
- ZEN_DEBUG("HTTP3 is available");
+ return 0;
}
- else
+
+ // When HTTPS-only, return the HTTPS port as the effective port
+ if (m_InitialConfig.HttpsOnly && m_HttpsPort > 0)
{
- ZEN_DEBUG("HTTP3 is NOT available");
+ return m_HttpsPort;
}
-# endif
return EffectivePort;
}
@@ -1302,6 +1536,8 @@ HttpSysServer::Cleanup()
{
++m_IsShuttingDown;
+ UnbindSslCertificate();
+
if (m_RequestQueueHandle)
{
HttpCloseRequestQueue(m_RequestQueueHandle);
@@ -1321,23 +1557,122 @@ HttpSysServer::Cleanup()
}
}
+// {7E3F4B2A-1C8D-4A6E-B5F0-9D2E8C7A3B1F} - Fixed GUID for zenserver SSL bindings
+static constexpr GUID ZenServerSslAppId = {0x7E3F4B2A, 0x1C8D, 0x4A6E, {0xB5, 0xF0, 0x9D, 0x2E, 0x8C, 0x7A, 0x3B, 0x1F}};
+
+bool
+HttpSysServer::BindSslCertificate(int Port)
+{
+ const std::string& Thumbprint = m_InitialConfig.CertThumbprint;
+ if (Thumbprint.size() != 40)
+ {
+ ZEN_ERROR("SSL certificate thumbprint must be exactly 40 hex characters, got {}", Thumbprint.size());
+ return false;
+ }
+
+ BYTE CertHash[20] = {};
+ if (!ParseHexBytes(Thumbprint, CertHash))
+ {
+ ZEN_ERROR("SSL certificate thumbprint contains invalid hex characters");
+ return false;
+ }
+
+ SOCKADDR_IN Address = {};
+ Address.sin_family = AF_INET;
+ Address.sin_port = htons(static_cast<USHORT>(Port));
+ Address.sin_addr.s_addr = INADDR_ANY;
+
+ const std::wstring StoreNameW = UTF8_to_UTF16(m_InitialConfig.CertStoreName.c_str());
+
+ HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {};
+ SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+ SslConfig.ParamDesc.pSslHash = CertHash;
+ SslConfig.ParamDesc.SslHashLength = sizeof(CertHash);
+ SslConfig.ParamDesc.pSslCertStoreName = const_cast<PWSTR>(StoreNameW.c_str());
+ SslConfig.ParamDesc.AppId = ZenServerSslAppId;
+
+ ULONG Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+
+ if (Result == ERROR_ALREADY_EXISTS)
+ {
+ // Remove existing binding and retry
+ HTTP_SERVICE_CONFIG_SSL_SET DeleteConfig = {};
+ DeleteConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+
+ HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &DeleteConfig, sizeof(DeleteConfig), nullptr);
+
+ Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+ }
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR(
+ "Failed to bind SSL certificate to port {}: {} ({:#x}). "
+ "This operation may require running as administrator.",
+ Port,
+ GetSystemErrorAsString(Result),
+ Result);
+ return false;
+ }
+
+ m_DidAutoBindCert = true;
+ m_HttpsPort = Port;
+
+ ZEN_INFO("SSL certificate auto-bound for 0.0.0.0:{} (thumbprint: {}..., store: {})",
+ Port,
+ Thumbprint.substr(0, 8),
+ m_InitialConfig.CertStoreName);
+
+ return true;
+}
+
+void
+HttpSysServer::UnbindSslCertificate()
+{
+ if (!m_DidAutoBindCert)
+ {
+ return;
+ }
+
+ SOCKADDR_IN Address = {};
+ Address.sin_family = AF_INET;
+ Address.sin_port = htons(static_cast<USHORT>(m_HttpsPort));
+ Address.sin_addr.s_addr = INADDR_ANY;
+
+ HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {};
+ SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+
+ ULONG Result = HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_WARN("Failed to remove SSL certificate binding from port {}: {} ({:#x})", m_HttpsPort, GetSystemErrorAsString(Result), Result);
+ }
+ else
+ {
+ ZEN_INFO("SSL certificate binding removed from port {}", m_HttpsPort);
+ }
+
+ m_DidAutoBindCert = false;
+}
+
WorkerThreadPool&
HttpSysServer::WorkPool()
{
ZEN_MEMSCOPE(GetHttpsysTag());
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_acquire))
{
RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock);
- if (!m_AsyncWorkPool)
+ if (!m_AsyncWorkPool.load(std::memory_order_relaxed))
{
m_AsyncWorkPool =
new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async", m_InitialConfig.UseExplicitIoThreadPool);
}
}
- return *m_AsyncWorkPool;
+ return *m_AsyncWorkPool.load(std::memory_order_relaxed);
}
void
@@ -1449,19 +1784,23 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service)
// Convert to wide string
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
-
- ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
-
- if (Result != NO_ERROR)
+ auto RegisterWithBaseUris = [&](const std::vector<std::wstring>& BaseUris) {
+ for (const std::wstring& BaseUri : BaseUris)
{
- ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ std::wstring Url16 = BaseUri + PathUtf16;
- return;
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ return;
+ }
}
- }
+ };
+
+ RegisterWithBaseUris(m_BaseUris);
+ RegisterWithBaseUris(m_HttpsBaseUris);
}
void
@@ -1476,19 +1815,22 @@ HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service)
const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath);
- // Convert to wide string
-
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
+ auto UnregisterFromBaseUris = [&](const std::vector<std::wstring>& BaseUris) {
+ for (const std::wstring& BaseUri : BaseUris)
+ {
+ std::wstring Url16 = BaseUri + PathUtf16;
- ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
+ ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
- if (Result != NO_ERROR)
- {
- ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ }
}
- }
+ };
+
+ UnregisterFromBaseUris(m_BaseUris);
+ UnregisterFromBaseUris(m_HttpsBaseUris);
}
//////////////////////////////////////////////////////////////////////////
@@ -1551,7 +1893,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
// than one thread at any given moment. This means we need to be careful about what
// happens in here
- HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped);
+ HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped);
+
+ switch (IoContext->ContextType)
+ {
+ case HttpSysIoContext::Type::kWebSocketRead:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kWebSocketWrite:
+ static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred);
+ return;
+
+ case HttpSysIoContext::Type::kTransaction:
+ break;
+ }
+
+ HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext);
// Assign names to threads for context (only needed when using Windows' native
// thread pool)
@@ -1675,6 +2033,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
{
HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload);
+ m_HttpServer.MarkRequest();
+
// Default request handling
# if ZEN_WITH_OTEL
@@ -1884,6 +2244,17 @@ HttpSysServerRequest::IsLocalMachineRequest() const
}
std::string_view
+HttpSysServerRequest::GetRemoteAddress() const
+{
+ if (m_RemoteAddress.Size() == 0)
+ {
+ const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress;
+ GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false);
+ }
+ return m_RemoteAddress.ToView();
+}
+
+std::string_view
HttpSysServerRequest::GetAuthorizationHeader() const
{
const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
@@ -2111,6 +2482,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
break;
}
+ Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed);
+
ZEN_TRACE_CPU("httpsys::HandleCompletion");
// Route request
@@ -2119,64 +2492,122 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
{
HTTP_REQUEST* HttpReq = HttpRequest();
-# if 0
- for (int i = 0; i < HttpReq->RequestInfoCount; ++i)
+ if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
{
- auto& ReqInfo = HttpReq->pRequestInfo[i];
-
- switch (ReqInfo.InfoType)
+ // WebSocket upgrade detection
+ if (m_IsInitialRequest)
{
- case HttpRequestInfoTypeRequestTiming:
+ const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade];
+ if (UpgradeHeader.RawValueLength > 0 &&
+ StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0)
+ {
+ if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service))
{
- const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo);
+ // Extract Sec-WebSocket-Key from the unknown headers
+ // (http.sys has no known-header slot for it)
+ std::string_view SecWebSocketKey;
+ for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i)
+ {
+ const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i];
+ if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0)
+ {
+ SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength);
+ break;
+ }
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeAuth:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeChannelBind:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslProtocol:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBindingDraft:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBinding:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV0:
- {
- const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo);
+ if (SecWebSocketKey.empty())
+ {
+ ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header");
+ return nullptr;
+ }
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeRequestSizing:
- {
- const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo);
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeQuicStats:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV1:
- {
- const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo);
+ const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey);
+
+ HANDLE RequestQueueHandle = Transaction().RequestQueueHandle();
+ HTTP_REQUEST_ID RequestId = HttpReq->RequestId;
+
+ // Build the 101 Switching Protocols response
+ HTTP_RESPONSE Response = {};
+ Response.StatusCode = 101;
+ Response.pReason = "Switching Protocols";
+ Response.ReasonLength = (USHORT)strlen(Response.pReason);
+
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket";
+ Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9;
+
+ eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders;
+
+ // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders
+ // despite there being an entry for it there (HttpHeaderConnection). If you try to do
+ // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below
+
+ UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"});
+
+ UnknownHeaders.push_back({.NameLength = 20,
+ .RawValueLength = (USHORT)AcceptKey.size(),
+ .pName = "Sec-WebSocket-Accept",
+ .pRawValue = AcceptKey.c_str()});
+
+ Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size();
+ Response.Headers.pUnknownHeaders = UnknownHeaders.data();
+
+ const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
+
+ // Use an OVERLAPPED with an event so we can wait synchronously.
+ // The request queue is IOCP-associated, so passing NULL for pOverlapped
+ // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent
+ // prevents IOCP delivery and lets us wait on the event directly.
+ OVERLAPPED SendOverlapped = {};
+ HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr);
+ SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1);
+
+ ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle,
+ RequestId,
+ Flags,
+ &Response,
+ nullptr, // CachePolicy
+ nullptr, // BytesSent
+ nullptr, // Reserved1
+ 0, // Reserved2
+ &SendOverlapped,
+ nullptr // LogData
+ );
+
+ if (SendResult == ERROR_IO_PENDING)
+ {
+ WaitForSingleObject(SendEvent, INFINITE);
+ SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE;
+ }
+
+ CloseHandle(SendEvent);
+
+ if (SendResult == NO_ERROR)
+ {
+ Transaction().Server().OnWebSocketConnectionOpened();
+ Ref<WsHttpSysConnection> WsConn(new WsHttpSysConnection(RequestQueueHandle,
+ RequestId,
+ *WsHandler,
+ Transaction().Iocp(),
+ &Transaction().Server()));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+
+ return nullptr;
+ }
+
+ ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult);
- ZEN_INFO("");
+ // WebSocket upgrade failed — return nullptr since ServerRequest()
+ // was never populated (no InvokeRequestHandler call)
+ return nullptr;
}
- break;
+ // Service doesn't support WebSocket or missing key — fall through to normal handling
+ }
}
- }
-# endif
- if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
- {
if (m_IsInitialRequest)
{
m_ContentLength = GetContentLength(HttpReq);
@@ -2242,6 +2673,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
}
}
+ else
+ {
+ // If a default redirect is configured and the request is for the root path, send a 302
+ std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect();
+ std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength);
+ if (!DefaultRedirect.empty() && (RawUrl == "/" || RawUrl.empty()))
+ {
+ auto* Response = new HttpMessageResponseRequest(Transaction(), 302);
+ Response->SetLocationHeader(DefaultRedirect);
+ return Response;
+ }
+ }
// Unable to route
return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv);
@@ -2285,6 +2728,11 @@ HttpSysServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
ZEN_UNUSED(DataDir);
if (int EffectivePort = InitializeServer(BasePort))
{
+ if (m_HttpsPort > 0)
+ {
+ SetEffectiveHttpsPort(m_HttpsPort);
+ }
+
StartServer();
return EffectivePort;
@@ -2301,6 +2749,52 @@ HttpSysServer::OnRequestExit()
m_ShutdownEvent.Set();
}
+std::string
+HttpSysServer::OnGetExternalHost() const
+{
+ // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback
+ bool IsPublic = false;
+ for (const auto& Uri : m_BaseUris)
+ {
+ if (Uri.find(L'*') != std::wstring::npos)
+ {
+ IsPublic = true;
+ break;
+ }
+ }
+
+ if (!IsPublic)
+ {
+ return "127.0.0.1";
+ }
+
+ // Use the UDP connect trick: connecting a UDP socket to an external address
+ // causes the OS to select the appropriate local interface without sending any data.
+ try
+ {
+ asio::io_context IoService;
+ asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4());
+ Sock.connect(asio::ip::udp::endpoint(asio::ip::make_address("8.8.8.8"), 80));
+ return Sock.local_endpoint().address().to_string();
+ }
+ catch (const std::exception&)
+ {
+ return GetMachineName();
+ }
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesReceived() const
+{
+ return m_TotalBytesReceived.load(std::memory_order_relaxed);
+}
+
+uint64_t
+HttpSysServer::GetTotalBytesSent() const
+{
+ return m_TotalBytesSent.load(std::memory_order_relaxed);
+}
+
void
HttpSysServer::OnRegisterService(HttpService& Service)
{
diff --git a/src/zenhttp/servers/httpsys.h b/src/zenhttp/servers/httpsys.h
index 4ff6df1fa..0685b42b2 100644
--- a/src/zenhttp/servers/httpsys.h
+++ b/src/zenhttp/servers/httpsys.h
@@ -23,6 +23,10 @@ struct HttpSysConfig
bool IsDedicatedServer = false;
bool ForceLoopback = false;
bool UseExplicitIoThreadPool = false;
+ int HttpsPort = 0; // 0 = HTTPS disabled
+ std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding
+ std::string CertStoreName = "MY"; // Windows certificate store name
+ bool HttpsOnly = false; // When true, disable HTTP listener
};
Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config);
diff --git a/src/zenhttp/servers/httpsys_iocontext.h b/src/zenhttp/servers/httpsys_iocontext.h
new file mode 100644
index 000000000..4f8a97012
--- /dev/null
+++ b/src/zenhttp/servers/httpsys_iocontext.h
@@ -0,0 +1,40 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+
+# include <cstdint>
+
+namespace zen {
+
+/**
+ * Tagged OVERLAPPED wrapper for http.sys IOCP dispatch
+ *
+ * Both HttpSysTransaction (for normal HTTP request I/O) and WsHttpSysConnection
+ * (for WebSocket read/write) embed this struct. The single IoCompletionCallback
+ * bound to the request queue uses the ContextType tag to dispatch to the correct
+ * handler.
+ *
+ * The Overlapped member must be first so that CONTAINING_RECORD works to recover
+ * the HttpSysIoContext from the OVERLAPPED pointer provided by the threadpool.
+ */
+struct HttpSysIoContext
+{
+ OVERLAPPED Overlapped{};
+
+ enum class Type : uint8_t
+ {
+ kTransaction,
+ kWebSocketRead,
+ kWebSocketWrite,
+ } ContextType = Type::kTransaction;
+
+ void* Owner = nullptr;
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/httptracer.h b/src/zenhttp/servers/httptracer.h
index da72c79c9..a9a45f162 100644
--- a/src/zenhttp/servers/httptracer.h
+++ b/src/zenhttp/servers/httptracer.h
@@ -1,9 +1,9 @@
// Copyright Epic Games, Inc. All Rights Reserved.
-#include <zenhttp/httpserver.h>
-
#pragma once
+#include <zenhttp/httpserver.h>
+
namespace zen {
/** Helper class for HTTP server implementations
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
new file mode 100644
index 000000000..5ae48f5b3
--- /dev/null
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -0,0 +1,339 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wsasio.h"
+#include "asio_socket_traits.h"
+#include "wsframecodec.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/httpserver.h>
+
+namespace zen::asio_http {
+
+static LoggerRef
+WsLog()
+{
+ static LoggerRef g_Logger = logging::Get("ws");
+ return g_Logger;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+template<typename SocketType>
+WsAsioConnectionT<SocketType>::WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server)
+: m_Socket(std::move(Socket))
+, m_Handler(Handler)
+, m_HttpServer(Server)
+{
+}
+
+template<typename SocketType>
+WsAsioConnectionT<SocketType>::~WsAsioConnectionT()
+{
+ m_IsOpen.store(false);
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketConnectionClosed();
+ }
+}
+
+template<typename SocketType>
+void
+WsAsioConnectionT<SocketType>::Start()
+{
+ EnqueueRead();
+}
+
+template<typename SocketType>
+bool
+WsAsioConnectionT<SocketType>::IsOpen() const
+{
+ return m_IsOpen.load(std::memory_order_relaxed);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Read loop
+//
+
+template<typename SocketType>
+void
+WsAsioConnectionT<SocketType>::EnqueueRead()
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ Ref<WsAsioConnectionT> Self(this);
+
+ asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) {
+ Self->OnDataReceived(Ec, ByteCount);
+ });
+}
+
+template<typename SocketType>
+void
+WsAsioConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "WebSocket read error: {}", Ec.message());
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection lost");
+ }
+ return;
+ }
+
+ ProcessReceivedData();
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ EnqueueRead();
+ }
+}
+
+template<typename SocketType>
+void
+WsAsioConnectionT<SocketType>::ProcessReceivedData()
+{
+ while (m_ReadBuffer.size() > 0)
+ {
+ const auto& InputBuffer = m_ReadBuffer.data();
+ const auto* Data = static_cast<const uint8_t*>(InputBuffer.data());
+ const auto Size = InputBuffer.size();
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size);
+ if (!Frame.IsValid)
+ {
+ break; // not enough data yet
+ }
+
+ m_ReadBuffer.consume(Frame.BytesConsumed);
+
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed);
+ }
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with pong carrying the same payload
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ // Unsolicited pong — ignore per RFC 6455
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo close frame back if we haven't sent one yet
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+
+ // Shut down the socket
+ std::error_code ShutdownEc;
+ SocketTraits<SocketType>::ShutdownBoth(*m_Socket, ShutdownEc);
+ SocketTraits<SocketType>::Close(*m_Socket, ShutdownEc);
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(WsLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Write queue
+//
+
+template<typename SocketType>
+void
+WsAsioConnectionT<SocketType>::SendText(std::string_view Text)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+}
+
+template<typename SocketType>
+void
+WsAsioConnectionT<SocketType>::SendBinary(std::span<const uint8_t> Data)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+}
+
+template<typename SocketType>
+void
+WsAsioConnectionT<SocketType>::Close(uint16_t Code, std::string_view Reason)
+{
+ DoClose(Code, Reason);
+}
+
+template<typename SocketType>
+void
+WsAsioConnectionT<SocketType>::DoClose(uint16_t Code, std::string_view Reason)
+{
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+}
+
+template<typename SocketType>
+void
+WsAsioConnectionT<SocketType>::EnqueueWrite(std::vector<uint8_t> Frame)
+{
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameSent(Frame.size());
+ }
+
+ bool ShouldFlush = false;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_WriteQueue.push_back(std::move(Frame));
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ });
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+}
+
+template<typename SocketType>
+void
+WsAsioConnectionT<SocketType>::FlushWriteQueue()
+{
+ std::vector<uint8_t> Frame;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+ Frame = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ });
+
+ if (Frame.empty())
+ {
+ return;
+ }
+
+ Ref<WsAsioConnectionT> Self(this);
+
+ // Move Frame into a shared_ptr so we can create the buffer and capture ownership
+ // in the same async_write call without evaluation order issues.
+ auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
+
+ asio::async_write(*m_Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); });
+}
+
+template<typename SocketType>
+void
+WsAsioConnectionT<SocketType>::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "WebSocket write error: {}", Ec.message());
+ }
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_IsWriting = false;
+ m_WriteQueue.clear();
+ });
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+ return;
+ }
+
+ FlushWriteQueue();
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Explicit template instantiations
+
+template class WsAsioConnectionT<asio::ip::tcp::socket>;
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+template class WsAsioConnectionT<asio::local::stream_protocol::socket>;
+#endif
+
+#if ZEN_USE_OPENSSL
+template class WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>;
+#endif
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h
new file mode 100644
index 000000000..64602ee46
--- /dev/null
+++ b/src/zenhttp/servers/wsasio.h
@@ -0,0 +1,94 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include <zencore/thread.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <deque>
+#include <memory>
+#include <vector>
+
+namespace zen {
+class HttpServer;
+} // namespace zen
+
+namespace zen::asio_http {
+
+/**
+ * WebSocket connection over an ASIO stream socket
+ *
+ * Templated on SocketType to support both TCP and Unix domain sockets.
+ * Owns the socket (moved from HttpServerConnection after the 101 handshake)
+ * and runs an async read/write loop to exchange WebSocket frames.
+ *
+ * Lifetime is managed solely through intrusive reference counting (RefCounted).
+ * The async read/write callbacks capture Ref<> to keep the connection alive
+ * for the duration of the async operation. The service layer also holds a
+ * Ref<WebSocketConnection>.
+ */
+template<typename SocketType>
+class WsAsioConnectionT : public WebSocketConnection
+{
+public:
+ WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server);
+ ~WsAsioConnectionT() override;
+
+ /**
+ * Start the async read loop. Must be called once after construction
+ * and the 101 response has been sent.
+ */
+ void Start();
+
+ // WebSocketConnection interface
+ void SendText(std::string_view Text) override;
+ void SendBinary(std::span<const uint8_t> Data) override;
+ void Close(uint16_t Code, std::string_view Reason) override;
+ bool IsOpen() const override;
+
+private:
+ void EnqueueRead();
+ void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
+ void ProcessReceivedData();
+
+ void EnqueueWrite(std::vector<uint8_t> Frame);
+ void FlushWriteQueue();
+ void OnWriteComplete(const asio::error_code& Ec, std::size_t ByteCount);
+
+ void DoClose(uint16_t Code, std::string_view Reason);
+
+ std::unique_ptr<SocketType> m_Socket;
+ IWebSocketHandler& m_Handler;
+ zen::HttpServer* m_HttpServer;
+ asio::streambuf m_ReadBuffer;
+
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ bool m_IsWriting = false;
+
+ std::atomic<bool> m_IsOpen{true};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+using WsAsioConnection = WsAsioConnectionT<asio::ip::tcp::socket>;
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+using WsAsioUnixConnection = WsAsioConnectionT<asio::local::stream_protocol::socket>;
+#endif
+
+#if ZEN_USE_OPENSSL
+using WsAsioSslConnection = WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>;
+#endif
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp
new file mode 100644
index 000000000..e452141fe
--- /dev/null
+++ b/src/zenhttp/servers/wsframecodec.cpp
@@ -0,0 +1,236 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wsframecodec.h"
+
+#include <zencore/base64.h>
+#include <zencore/sha1.h>
+
+#include <cstring>
+#include <random>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame parsing
+//
+
+WsFrameParseResult
+WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size)
+{
+ // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames)
+ if (Size < 2)
+ {
+ return {};
+ }
+
+ const bool Fin = (Data[0] & 0x80) != 0;
+ const uint8_t OpcodeRaw = Data[0] & 0x0F;
+ const bool Masked = (Data[1] & 0x80) != 0;
+ uint64_t PayloadLen = Data[1] & 0x7F;
+
+ size_t HeaderSize = 2;
+
+ if (PayloadLen == 126)
+ {
+ if (Size < 4)
+ {
+ return {};
+ }
+ PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]);
+ HeaderSize = 4;
+ }
+ else if (PayloadLen == 127)
+ {
+ if (Size < 10)
+ {
+ return {};
+ }
+ PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) |
+ (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]);
+ HeaderSize = 10;
+ }
+
+ // Reject frames with unreasonable payload sizes to prevent OOM
+ static constexpr uint64_t kMaxPayloadSize = 256 * 1024 * 1024; // 256 MB
+ if (PayloadLen > kMaxPayloadSize)
+ {
+ return {};
+ }
+
+ const size_t MaskSize = Masked ? 4 : 0;
+ const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen;
+
+ if (Size < TotalFrame)
+ {
+ return {};
+ }
+
+ const uint8_t* MaskKey = Masked ? (Data + HeaderSize) : nullptr;
+ const uint8_t* PayloadData = Data + HeaderSize + MaskSize;
+
+ WsFrameParseResult Result;
+ Result.IsValid = true;
+ Result.BytesConsumed = TotalFrame;
+ Result.Opcode = static_cast<WebSocketOpcode>(OpcodeRaw);
+ Result.Fin = Fin;
+
+ Result.Payload.resize(static_cast<size_t>(PayloadLen));
+ if (PayloadLen > 0)
+ {
+ std::memcpy(Result.Payload.data(), PayloadData, static_cast<size_t>(PayloadLen));
+
+ if (Masked)
+ {
+ for (size_t i = 0; i < Result.Payload.size(); ++i)
+ {
+ Result.Payload[i] ^= MaskKey[i & 3];
+ }
+ }
+ }
+
+ return Result;
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame building (server-to-client, no masking)
+//
+
+std::vector<uint8_t>
+WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+{
+ std::vector<uint8_t> Frame;
+
+ const size_t PayloadLen = Payload.size();
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length (no mask bit for server frames)
+ if (PayloadLen < 126)
+ {
+ Frame.push_back(static_cast<uint8_t>(PayloadLen));
+ }
+ else if (PayloadLen <= 0xFFFF)
+ {
+ Frame.push_back(126);
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF));
+ }
+ }
+
+ Frame.insert(Frame.end(), Payload.begin(), Payload.end());
+
+ return Frame;
+}
+
+std::vector<uint8_t>
+WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason)
+{
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ Payload.insert(Payload.end(), Reason.begin(), Reason.end());
+
+ return BuildFrame(WebSocketOpcode::kClose, Payload);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame building (client-to-server, with masking)
+//
+
+std::vector<uint8_t>
+WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+{
+ std::vector<uint8_t> Frame;
+
+ const size_t PayloadLen = Payload.size();
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length with mask bit set
+ if (PayloadLen < 126)
+ {
+ Frame.push_back(0x80 | static_cast<uint8_t>(PayloadLen));
+ }
+ else if (PayloadLen <= 0xFFFF)
+ {
+ Frame.push_back(0x80 | 126);
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(0x80 | 127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF));
+ }
+ }
+
+ // Generate random 4-byte mask key
+ static thread_local std::mt19937 s_Rng(std::random_device{}());
+ uint32_t MaskValue = s_Rng();
+ uint8_t MaskKey[4];
+ std::memcpy(MaskKey, &MaskValue, 4);
+
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+
+ // Masked payload
+ for (size_t i = 0; i < PayloadLen; ++i)
+ {
+ Frame.push_back(Payload[i] ^ MaskKey[i & 3]);
+ }
+
+ return Frame;
+}
+
+std::vector<uint8_t>
+WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason)
+{
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ Payload.insert(Payload.end(), Reason.begin(), Reason.end());
+
+ return BuildMaskedFrame(WebSocketOpcode::kClose, Payload);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2)
+//
+
+static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+
+std::string
+WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey)
+{
+ // Concatenate client key with the magic GUID
+ std::string Combined;
+ Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size());
+ Combined.append(ClientKey);
+ Combined.append(kWebSocketMagicGuid);
+
+ // SHA1 hash
+ SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size());
+
+ // Base64 encode the 20-byte hash
+ char Base64Buf[Base64::GetEncodedDataSize(20) + 1];
+ uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf);
+ Base64Buf[EncodedLen] = '\0';
+
+ return std::string(Base64Buf, EncodedLen);
+}
+
+} // namespace zen
diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h
new file mode 100644
index 000000000..2d90b6fa1
--- /dev/null
+++ b/src/zenhttp/servers/wsframecodec.h
@@ -0,0 +1,74 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <span>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace zen {
+
+/**
+ * Result of attempting to parse a single WebSocket frame from a byte buffer
+ */
+struct WsFrameParseResult
+{
+ bool IsValid = false; // true if a complete frame was successfully parsed
+ size_t BytesConsumed = 0; // number of bytes consumed from the input buffer
+ WebSocketOpcode Opcode = WebSocketOpcode::kText;
+ bool Fin = false;
+ std::vector<uint8_t> Payload;
+};
+
+/**
+ * RFC 6455 WebSocket frame codec
+ *
+ * Provides static helpers for parsing client-to-server frames (which are
+ * always masked) and building server-to-client frames (which are never masked).
+ */
+struct WsFrameCodec
+{
+ /**
+ * Try to parse one complete frame from the front of the buffer.
+ *
+ * Returns a result with IsValid == false and BytesConsumed == 0 when
+ * there is not enough data yet. The caller should accumulate more data
+ * and retry.
+ */
+ static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size);
+
+ /**
+ * Build a server-to-client frame (no masking)
+ */
+ static std::vector<uint8_t> BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload);
+
+ /**
+ * Build a close frame with a status code and optional reason string
+ */
+ static std::vector<uint8_t> BuildCloseFrame(uint16_t Code, std::string_view Reason = {});
+
+ /**
+ * Build a client-to-server frame (with masking per RFC 6455)
+ */
+ static std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload);
+
+ /**
+ * Build a masked close frame with status code and optional reason
+ */
+ static std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason = {});
+
+ /**
+ * Compute the Sec-WebSocket-Accept value per RFC 6455 section 4.2.2
+ *
+ * accept = Base64(SHA1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
+ */
+ static std::string ComputeAcceptKey(std::string_view ClientKey);
+};
+
+} // namespace zen
diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp
new file mode 100644
index 000000000..af320172d
--- /dev/null
+++ b/src/zenhttp/servers/wshttpsys.cpp
@@ -0,0 +1,485 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "wshttpsys.h"
+
+#if ZEN_WITH_HTTPSYS
+
+# include "wsframecodec.h"
+
+# include <zencore/logging.h>
+# include <zenhttp/httpserver.h>
+
+namespace zen {
+
+static LoggerRef
+WsHttpSysLog()
+{
+ static LoggerRef g_Logger = logging::Get("ws_httpsys");
+ return g_Logger;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle,
+ HTTP_REQUEST_ID RequestId,
+ IWebSocketHandler& Handler,
+ PTP_IO Iocp,
+ HttpServer* Server)
+: m_RequestQueueHandle(RequestQueueHandle)
+, m_RequestId(RequestId)
+, m_Handler(Handler)
+, m_Iocp(Iocp)
+, m_HttpServer(Server)
+, m_ReadBuffer(8192)
+{
+ m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead;
+ m_ReadIoContext.Owner = this;
+ m_WriteIoContext.ContextType = HttpSysIoContext::Type::kWebSocketWrite;
+ m_WriteIoContext.Owner = this;
+}
+
+WsHttpSysConnection::~WsHttpSysConnection()
+{
+ ZEN_ASSERT(m_OutstandingOps.load() == 0);
+
+ if (m_IsOpen.exchange(false))
+ {
+ Disconnect();
+ }
+
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketConnectionClosed();
+ }
+}
+
+void
+WsHttpSysConnection::Start()
+{
+ m_SelfRef = Ref<WsHttpSysConnection>(this);
+ IssueAsyncRead();
+}
+
+void
+WsHttpSysConnection::Shutdown()
+{
+ m_ShutdownRequested.store(true, std::memory_order_relaxed);
+
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED
+ HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
+}
+
+bool
+WsHttpSysConnection::IsOpen() const
+{
+ return m_IsOpen.load(std::memory_order_relaxed);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Async read path
+//
+
+void
+WsHttpSysConnection::IssueAsyncRead()
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed) || m_ShutdownRequested.load(std::memory_order_relaxed))
+ {
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ m_OutstandingOps.fetch_add(1, std::memory_order_relaxed);
+
+ ZeroMemory(&m_ReadIoContext.Overlapped, sizeof(OVERLAPPED));
+
+ StartThreadpoolIo(m_Iocp);
+
+ ULONG Result = HttpReceiveRequestEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ 0, // Flags
+ m_ReadBuffer.data(),
+ (ULONG)m_ReadBuffer.size(),
+ nullptr, // BytesRead (ignored for async)
+ &m_ReadIoContext.Overlapped);
+
+ if (Result != NO_ERROR && Result != ERROR_IO_PENDING)
+ {
+ CancelThreadpoolIo(m_Iocp);
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "read issue failed");
+ }
+
+ MaybeReleaseSelfRef();
+ }
+}
+
+void
+WsHttpSysConnection::OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ // Hold a transient ref to prevent mid-callback destruction after MaybeReleaseSelfRef
+ Ref<WsHttpSysConnection> Guard(this);
+
+ if (IoResult != NO_ERROR)
+ {
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.exchange(false))
+ {
+ if (IoResult == ERROR_HANDLE_EOF)
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection closed");
+ }
+ else if (IoResult != ERROR_OPERATION_ABORTED)
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "connection lost");
+ }
+ }
+
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ if (NumberOfBytesTransferred > 0)
+ {
+ m_Accumulated.insert(m_Accumulated.end(), m_ReadBuffer.begin(), m_ReadBuffer.begin() + NumberOfBytesTransferred);
+ ProcessReceivedData();
+ }
+
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ IssueAsyncRead();
+ }
+ else
+ {
+ MaybeReleaseSelfRef();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Frame parsing
+//
+
+void
+WsHttpSysConnection::ProcessReceivedData()
+{
+ while (!m_Accumulated.empty())
+ {
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size());
+ if (!Frame.IsValid)
+ {
+ break; // not enough data yet
+ }
+
+ // Remove consumed bytes
+ m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed);
+
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed);
+ }
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with pong carrying the same payload
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ // Unsolicited pong — ignore per RFC 6455
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo close frame back if we haven't sent one yet
+ {
+ bool ShouldSendClose = false;
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ if (!m_CloseSent.exchange(true))
+ {
+ ShouldSendClose = true;
+ }
+ }
+ if (ShouldSendClose)
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+ Disconnect();
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(WsHttpSysLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Async write path
+//
+
+void
+WsHttpSysConnection::EnqueueWrite(std::vector<uint8_t> Frame)
+{
+ if (m_HttpServer)
+ {
+ m_HttpServer->OnWebSocketFrameSent(Frame.size());
+ }
+
+ bool ShouldFlush = false;
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.push_back(std::move(Frame));
+
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ }
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+}
+
+void
+WsHttpSysConnection::FlushWriteQueue()
+{
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+
+ m_CurrentWriteBuffer = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ }
+
+ m_OutstandingOps.fetch_add(1, std::memory_order_relaxed);
+
+ ZeroMemory(&m_WriteChunk, sizeof(m_WriteChunk));
+ m_WriteChunk.DataChunkType = HttpDataChunkFromMemory;
+ m_WriteChunk.FromMemory.pBuffer = m_CurrentWriteBuffer.data();
+ m_WriteChunk.FromMemory.BufferLength = (ULONG)m_CurrentWriteBuffer.size();
+
+ ZeroMemory(&m_WriteIoContext.Overlapped, sizeof(OVERLAPPED));
+
+ StartThreadpoolIo(m_Iocp);
+
+ ULONG Result = HttpSendResponseEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ HTTP_SEND_RESPONSE_FLAG_MORE_DATA,
+ 1,
+ &m_WriteChunk,
+ nullptr,
+ nullptr,
+ 0,
+ &m_WriteIoContext.Overlapped,
+ nullptr);
+
+ if (Result != NO_ERROR && Result != ERROR_IO_PENDING)
+ {
+ CancelThreadpoolIo(m_Iocp);
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+
+ ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket async write failed: {}", Result);
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.clear();
+ m_IsWriting = false;
+ }
+ m_CurrentWriteBuffer.clear();
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+
+ MaybeReleaseSelfRef();
+ }
+}
+
+void
+WsHttpSysConnection::OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ ZEN_UNUSED(NumberOfBytesTransferred);
+
+ // Hold a transient ref to prevent mid-callback destruction
+ Ref<WsHttpSysConnection> Guard(this);
+
+ m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed);
+ m_CurrentWriteBuffer.clear();
+
+ if (IoResult != NO_ERROR)
+ {
+ ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket write completion error: {}", IoResult);
+
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ m_WriteQueue.clear();
+ m_IsWriting = false;
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWebSocketClose(*this, 1006, "write error");
+ }
+
+ MaybeReleaseSelfRef();
+ return;
+ }
+
+ FlushWriteQueue();
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Send interface
+//
+
+void
+WsHttpSysConnection::SendText(std::string_view Text)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsHttpSysConnection::SendBinary(std::span<const uint8_t> Data)
+{
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+}
+
+void
+WsHttpSysConnection::Close(uint16_t Code, std::string_view Reason)
+{
+ DoClose(Code, Reason);
+}
+
+void
+WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason)
+{
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ {
+ bool ShouldSendClose = false;
+ {
+ RwLock::ExclusiveLockScope _(m_WriteLock);
+ if (!m_CloseSent.exchange(true))
+ {
+ ShouldSendClose = true;
+ }
+ }
+ if (ShouldSendClose)
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ m_Handler.OnWebSocketClose(*this, Code, Reason);
+
+ // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED
+ HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Lifetime management
+//
+
+void
+WsHttpSysConnection::MaybeReleaseSelfRef()
+{
+ if (m_OutstandingOps.load(std::memory_order_relaxed) == 0 && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ m_SelfRef = nullptr;
+ }
+}
+
+void
+WsHttpSysConnection::Disconnect()
+{
+ // Send final empty body with DISCONNECT to tell http.sys the connection is done
+ HttpSendResponseEntityBody(m_RequestQueueHandle,
+ m_RequestId,
+ HTTP_SEND_RESPONSE_FLAG_DISCONNECT,
+ 0,
+ nullptr,
+ nullptr,
+ nullptr,
+ 0,
+ nullptr,
+ nullptr);
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h
new file mode 100644
index 000000000..6015e3873
--- /dev/null
+++ b/src/zenhttp/servers/wshttpsys.h
@@ -0,0 +1,107 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/websocket.h>
+
+#include "httpsys_iocontext.h"
+
+#include <zencore/thread.h>
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+# include <http.h>
+
+# include <atomic>
+# include <deque>
+# include <vector>
+
+namespace zen {
+
+class HttpServer;
+
+/**
+ * WebSocket connection over an http.sys opaque-mode connection
+ *
+ * After the 101 Switching Protocols response is sent with
+ * HTTP_SEND_RESPONSE_FLAG_OPAQUE, http.sys stops parsing HTTP on the
+ * connection. Raw bytes are exchanged via HttpReceiveRequestEntityBody /
+ * HttpSendResponseEntityBody using the original RequestId.
+ *
+ * All I/O is performed asynchronously via the same IOCP threadpool used
+ * for normal http.sys traffic, eliminating per-connection threads.
+ *
+ * Lifetime is managed through intrusive reference counting (RefCounted).
+ * A self-reference (m_SelfRef) is held from Start() until all outstanding
+ * async operations have drained, preventing premature destruction.
+ */
+class WsHttpSysConnection : public WebSocketConnection
+{
+public:
+ WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp, HttpServer* Server);
+ ~WsHttpSysConnection() override;
+
+ /**
+ * Start the async read loop. Must be called once after construction
+ * and after the 101 response has been sent.
+ */
+ void Start();
+
+ /**
+ * Shut down the connection. Cancels pending I/O; IOCP completions
+ * will fire with ERROR_OPERATION_ABORTED and drain naturally.
+ */
+ void Shutdown();
+
+ // WebSocketConnection interface
+ void SendText(std::string_view Text) override;
+ void SendBinary(std::span<const uint8_t> Data) override;
+ void Close(uint16_t Code, std::string_view Reason) override;
+ bool IsOpen() const override;
+
+ // Called from IoCompletionCallback via tagged dispatch
+ void OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+ void OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+
+private:
+ void IssueAsyncRead();
+ void ProcessReceivedData();
+ void EnqueueWrite(std::vector<uint8_t> Frame);
+ void FlushWriteQueue();
+ void DoClose(uint16_t Code, std::string_view Reason);
+ void Disconnect();
+ void MaybeReleaseSelfRef();
+
+ HANDLE m_RequestQueueHandle;
+ HTTP_REQUEST_ID m_RequestId;
+ IWebSocketHandler& m_Handler;
+ PTP_IO m_Iocp;
+ HttpServer* m_HttpServer;
+
+ // Tagged OVERLAPPED contexts for concurrent read and write
+ HttpSysIoContext m_ReadIoContext{};
+ HttpSysIoContext m_WriteIoContext{};
+
+ // Read state
+ std::vector<uint8_t> m_ReadBuffer;
+ std::vector<uint8_t> m_Accumulated;
+
+ // Write state
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ std::vector<uint8_t> m_CurrentWriteBuffer;
+ HTTP_DATA_CHUNK m_WriteChunk{};
+ bool m_IsWriting = false;
+
+ // Lifetime management
+ std::atomic<int32_t> m_OutstandingOps{0};
+ Ref<WsHttpSysConnection> m_SelfRef;
+ std::atomic<bool> m_ShutdownRequested{false};
+ std::atomic<bool> m_IsOpen{true};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_HTTPSYS
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
new file mode 100644
index 000000000..59c46a418
--- /dev/null
+++ b/src/zenhttp/servers/wstest.cpp
@@ -0,0 +1,994 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/scopeguard.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+# include <zenhttp/httpserver.h>
+# include <zenhttp/httpwsclient.h>
+# include <zenhttp/websocket.h>
+
+# include "httpasio.h"
+# include "wsframecodec.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# if ZEN_PLATFORM_WINDOWS
+# include <winsock2.h>
+# else
+# include <poll.h>
+# include <sys/socket.h>
+# endif
+# include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+# include <atomic>
+# include <chrono>
+# include <cstring>
+# include <random>
+# include <string>
+# include <string_view>
+# include <thread>
+# include <vector>
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Unit tests: WsFrameCodec
+//
+
+TEST_SUITE_BEGIN("http.wstest");
+
+TEST_CASE("websocket.framecodec")
+{
+ SUBCASE("ComputeAcceptKey RFC 6455 test vector")
+ {
+ // RFC 6455 section 4.2.2 example
+ std::string AcceptKey = WsFrameCodec::ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
+ CHECK_EQ(AcceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
+ }
+
+ SUBCASE("BuildFrame and TryParseFrame roundtrip - text")
+ {
+ std::string_view Text = "Hello, WebSocket!";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+
+ // Server frames are unmasked — TryParseFrame should handle them
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, Frame.size());
+ CHECK(Result.Fin);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), Text.size());
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text);
+ }
+
+ SUBCASE("BuildFrame and TryParseFrame roundtrip - binary")
+ {
+ std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, BinaryData);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
+ CHECK_EQ(Result.Payload, BinaryData);
+ }
+
+ SUBCASE("BuildFrame - medium payload (126-65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(300, 0x42);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 300u);
+ CHECK_EQ(Result.Payload, Payload);
+ }
+
+ SUBCASE("BuildFrame - large payload (>65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(70000, 0xAB);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 70000u);
+ }
+
+ SUBCASE("BuildCloseFrame roundtrip")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildCloseFrame(1000, "normal closure");
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Result.Payload.size() >= 2);
+
+ uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]);
+ CHECK_EQ(Code, 1000);
+
+ std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
+ CHECK_EQ(Reason, "normal closure");
+ }
+
+ SUBCASE("TryParseFrame - partial data returns invalid")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
+
+ // Pass only 1 byte — not enough for a frame header
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1);
+ CHECK_FALSE(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, 0u);
+ }
+
+ SUBCASE("TryParseFrame - empty payload")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK(Result.Payload.empty());
+ }
+
+ SUBCASE("TryParseFrame - masked client frame")
+ {
+ // Build a masked frame manually as a client would send
+ // Frame: FIN=1, opcode=text, MASK=1, payload_len=5, mask_key=0x37FA213D, payload="Hello"
+ uint8_t MaskKey[4] = {0x37, 0xFA, 0x21, 0x3D};
+ uint8_t MaskedPayload[5] = {};
+ const char* Original = "Hello";
+ for (int i = 0; i < 5; ++i)
+ {
+ MaskedPayload[i] = static_cast<uint8_t>(Original[i]) ^ MaskKey[i % 4];
+ }
+
+ std::vector<uint8_t> Frame;
+ Frame.push_back(0x81); // FIN + text
+ Frame.push_back(0x85); // MASK + len=5
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+ Frame.insert(Frame.end(), MaskedPayload, MaskedPayload + 5);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), 5u);
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), 5), "Hello"sv);
+ }
+
+ SUBCASE("BuildMaskedFrame roundtrip - text")
+ {
+ std::string_view Text = "Hello, masked WebSocket!";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+
+ // Verify mask bit is set
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.BytesConsumed, Frame.size());
+ CHECK(Result.Fin);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
+ CHECK_EQ(Result.Payload.size(), Text.size());
+ CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text);
+ }
+
+ SUBCASE("BuildMaskedFrame roundtrip - binary")
+ {
+ std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, BinaryData);
+
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
+ CHECK_EQ(Result.Payload, BinaryData);
+ }
+
+ SUBCASE("BuildMaskedFrame - medium payload (126-65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(300, 0x42);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);
+
+ CHECK((Frame[1] & 0x80) != 0);
+ CHECK_EQ((Frame[1] & 0x7F), 126); // 16-bit extended length
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 300u);
+ CHECK_EQ(Result.Payload, Payload);
+ }
+
+ SUBCASE("BuildMaskedFrame - large payload (>65535 bytes)")
+ {
+ std::vector<uint8_t> Payload(70000, 0xAB);
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);
+
+ CHECK((Frame[1] & 0x80) != 0);
+ CHECK_EQ((Frame[1] & 0x7F), 127); // 64-bit extended length
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), 70000u);
+ }
+
+ SUBCASE("BuildMaskedCloseFrame roundtrip")
+ {
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure");
+
+ CHECK((Frame[1] & 0x80) != 0);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
+
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Result.Payload.size() >= 2);
+
+ uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]);
+ CHECK_EQ(Code, 1000);
+
+ std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
+ CHECK_EQ(Reason, "normal closure");
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Integration tests: WebSocket over ASIO
+//
+
+namespace {
+
+ /**
+ * Helper: Build a masked client-to-server frame per RFC 6455
+ */
+ std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
+ {
+ std::vector<uint8_t> Frame;
+
+ // FIN + opcode
+ Frame.push_back(0x80 | static_cast<uint8_t>(Opcode));
+
+ // Payload length with mask bit set
+ if (Payload.size() < 126)
+ {
+ Frame.push_back(0x80 | static_cast<uint8_t>(Payload.size()));
+ }
+ else if (Payload.size() <= 0xFFFF)
+ {
+ Frame.push_back(0x80 | 126);
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(Payload.size() & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(0x80 | 127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> (i * 8)) & 0xFF));
+ }
+ }
+
+ // Mask key (use a fixed key for deterministic tests)
+ uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78};
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+
+ // Masked payload
+ for (size_t i = 0; i < Payload.size(); ++i)
+ {
+ Frame.push_back(Payload[i] ^ MaskKey[i & 3]);
+ }
+
+ return Frame;
+ }
+
+ std::vector<uint8_t> BuildMaskedTextFrame(std::string_view Text)
+ {
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ return BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ }
+
+ std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code)
+ {
+ std::vector<uint8_t> Payload;
+ Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF));
+ Payload.push_back(static_cast<uint8_t>(Code & 0xFF));
+ return BuildMaskedFrame(WebSocketOpcode::kClose, Payload);
+ }
+
+ /**
+ * Test service that implements IWebSocketHandler
+ */
+ struct WsTestService : public HttpService, public IWebSocketHandler
+ {
+ const char* BaseUri() const override { return "/wstest/"; }
+
+ void HandleRequest(HttpServerRequest& Request) override
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello from wstest");
+ }
+
+ // IWebSocketHandler
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ {
+ m_OpenCount.fetch_add(1);
+
+ m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); });
+ }
+
+ void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override
+ {
+ m_MessageCount.fetch_add(1);
+
+ if (Msg.Opcode == WebSocketOpcode::kText)
+ {
+ std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
+ m_LastMessage = std::string(Text);
+
+ // Echo the message back
+ Conn.SendText(Text);
+ }
+ }
+
+ void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) override
+ {
+ m_CloseCount.fetch_add(1);
+ m_LastCloseCode = Code;
+
+ m_ConnectionsLock.WithExclusiveLock([&] {
+ auto It = std::remove_if(m_Connections.begin(), m_Connections.end(), [&Conn](const Ref<WebSocketConnection>& C) {
+ return C.Get() == &Conn;
+ });
+ m_Connections.erase(It, m_Connections.end());
+ });
+ }
+
+ void SendToAll(std::string_view Text)
+ {
+ RwLock::SharedLockScope _(m_ConnectionsLock);
+ for (auto& Conn : m_Connections)
+ {
+ if (Conn->IsOpen())
+ {
+ Conn->SendText(Text);
+ }
+ }
+ }
+
+ std::atomic<int> m_OpenCount{0};
+ std::atomic<int> m_MessageCount{0};
+ std::atomic<int> m_CloseCount{0};
+ std::atomic<uint16_t> m_LastCloseCode{0};
+ std::string m_LastMessage;
+
+ RwLock m_ConnectionsLock;
+ std::vector<Ref<WebSocketConnection>> m_Connections;
+ };
+
+ /**
+ * Helper: Perform the WebSocket upgrade handshake on a raw TCP socket
+ *
+ * Returns true on success (101 response), false otherwise.
+ */
+ bool DoWebSocketHandshake(asio::ip::tcp::socket& Sock, std::string_view Path, int Port)
+ {
+ // Send HTTP upgrade request
+ ExtendableStringBuilder<512> Request;
+ Request << "GET " << Path << " HTTP/1.1\r\n"
+ << "Host: 127.0.0.1:" << Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ << "Sec-WebSocket-Version: 13\r\n"
+ << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+
+ asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size()));
+
+ // Read the response (look for "101")
+ asio::streambuf ResponseBuf;
+ asio::read_until(Sock, ResponseBuf, "\r\n\r\n");
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+
+ return Response.find("101") != std::string::npos;
+ }
+
+ /**
+ * Helper: Read a single server-to-client frame from a socket
+ *
+ * Uses a background thread with a synchronous ASIO read and a timeout.
+ */
+ WsFrameParseResult ReadOneFrame(asio::ip::tcp::socket& Sock, int TimeoutMs = 5000)
+ {
+ std::vector<uint8_t> Buffer;
+ WsFrameParseResult Result;
+ std::atomic<bool> Done{false};
+
+ std::thread Reader([&] {
+ while (!Done.load())
+ {
+ uint8_t Tmp[4096];
+ asio::error_code Ec;
+ size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec);
+ if (Ec || BytesRead == 0)
+ {
+ break;
+ }
+
+ Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead);
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size());
+ if (Frame.IsValid)
+ {
+ Result = std::move(Frame);
+ Done.store(true);
+ return;
+ }
+ }
+ });
+
+ auto Deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(TimeoutMs);
+ while (!Done.load() && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ if (!Done.load())
+ {
+ // Timeout — cancel the read
+ asio::error_code Ec;
+ Sock.cancel(Ec);
+ }
+
+ if (Reader.joinable())
+ {
+ Reader.join();
+ }
+
+ return Result;
+ }
+
+} // anonymous namespace
+
+TEST_CASE("websocket.integration")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ // Give server a moment to start accepting
+ Sleep(100);
+
+ SUBCASE("handshake succeeds with 101")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ CHECK(Ok);
+
+ Sleep(50);
+ CHECK_EQ(TestService.m_OpenCount.load(), 1);
+
+ Sock.close();
+ }
+
+ SUBCASE("normal HTTP still works alongside WebSocket service")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ // Send a normal HTTP GET (not upgrade)
+ std::string HttpReq = fmt::format(
+ "GET /wstest/hello HTTP/1.1\r\n"
+ "Host: 127.0.0.1:{}\r\n"
+ "Connection: close\r\n"
+ "\r\n",
+ Port);
+
+ asio::write(Sock, asio::buffer(HttpReq));
+
+ asio::streambuf ResponseBuf;
+ asio::error_code Ec;
+ asio::read(Sock, ResponseBuf, asio::transfer_at_least(1), Ec);
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+ CHECK(Response.find("200") != std::string::npos);
+ }
+
+ SUBCASE("echo message roundtrip")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send a text message (masked, as client)
+ std::vector<uint8_t> Frame = BuildMaskedTextFrame("ping test");
+ asio::write(Sock, asio::buffer(Frame));
+
+ // Read the echo reply
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
+ std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(ReplyText, "ping test"sv);
+ CHECK_EQ(TestService.m_MessageCount.load(), 1);
+ CHECK_EQ(TestService.m_LastMessage, "ping test");
+
+ Sock.close();
+ }
+
+ SUBCASE("server push to client")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Server pushes a message
+ TestService.SendToAll("server says hello");
+
+ WsFrameParseResult Frame = ReadOneFrame(Sock);
+ REQUIRE(Frame.IsValid);
+ CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
+ std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size());
+ CHECK_EQ(Text, "server says hello"sv);
+
+ Sock.close();
+ }
+
+ SUBCASE("client close handshake")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send close frame
+ std::vector<uint8_t> CloseFrame = BuildMaskedCloseFrame(1000);
+ asio::write(Sock, asio::buffer(CloseFrame));
+
+ // Server should echo close back
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose);
+
+ Sleep(50);
+ CHECK_EQ(TestService.m_CloseCount.load(), 1);
+ CHECK_EQ(TestService.m_LastCloseCode.load(), 1000);
+
+ Sock.close();
+ }
+
+ SUBCASE("multiple concurrent connections")
+ {
+ constexpr int NumClients = 5;
+
+ asio::io_context IoCtx;
+ std::vector<asio::ip::tcp::socket> Sockets;
+
+ for (int i = 0; i < NumClients; ++i)
+ {
+ Sockets.emplace_back(IoCtx);
+ Sockets.back().connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sockets.back(), "/wstest/ws", Port);
+ REQUIRE(Ok);
+ }
+
+ Sleep(100);
+ CHECK_EQ(TestService.m_OpenCount.load(), NumClients);
+
+ // Broadcast from server
+ TestService.SendToAll("broadcast");
+
+ // Each client should receive the message
+ for (int i = 0; i < NumClients; ++i)
+ {
+ WsFrameParseResult Frame = ReadOneFrame(Sockets[i]);
+ REQUIRE(Frame.IsValid);
+ CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
+ std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size());
+ CHECK_EQ(Text, "broadcast"sv);
+ }
+
+ // Close all
+ for (auto& S : Sockets)
+ {
+ S.close();
+ }
+ }
+
+ SUBCASE("service without IWebSocketHandler rejects upgrade")
+ {
+ // Register a plain HTTP service (no WebSocket)
+ struct PlainService : public HttpService
+ {
+ const char* BaseUri() const override { return "/plain/"; }
+ void HandleRequest(HttpServerRequest& Request) override { Request.WriteResponse(HttpResponseCode::OK); }
+ };
+
+ PlainService Plain;
+ Server->RegisterService(Plain);
+
+ Sleep(50);
+
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ // Attempt WebSocket upgrade on the plain service
+ ExtendableStringBuilder<512> Request;
+ Request << "GET /plain/ws HTTP/1.1\r\n"
+ << "Host: 127.0.0.1:" << Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ << "Sec-WebSocket-Version: 13\r\n"
+ << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+ asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size()));
+
+ asio::streambuf ResponseBuf;
+ asio::read_until(Sock, ResponseBuf, "\r\n\r\n");
+
+ std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
+
+ // Should NOT get 101 — should fall through to normal request handling
+ CHECK(Response.find("101") == std::string::npos);
+
+ Sock.close();
+ }
+
+ SUBCASE("ping/pong auto-response")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send a ping frame with payload "test"
+ std::string_view PingPayload = "test";
+ std::span<const uint8_t> PingData(reinterpret_cast<const uint8_t*>(PingPayload.data()), PingPayload.size());
+ std::vector<uint8_t> PingFrame = BuildMaskedFrame(WebSocketOpcode::kPing, PingData);
+ asio::write(Sock, asio::buffer(PingFrame));
+
+ // Should receive a pong with the same payload
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kPong);
+ CHECK_EQ(Reply.Payload.size(), 4u);
+ std::string_view PongText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(PongText, "test"sv);
+
+ Sock.close();
+ }
+
+ SUBCASE("multiple messages in sequence")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ for (int i = 0; i < 10; ++i)
+ {
+ std::string Msg = fmt::format("message {}", i);
+ std::vector<uint8_t> Frame = BuildMaskedTextFrame(Msg);
+ asio::write(Sock, asio::buffer(Frame));
+
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
+ std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(ReplyText, Msg);
+ }
+
+ CHECK_EQ(TestService.m_MessageCount.load(), 10);
+
+ Sock.close();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Integration tests: HttpWsClient
+//
+
+namespace {
+
+ struct TestWsClientHandler : public IWsClientHandler
+ {
+ void OnWsOpen() override { m_OpenCount.fetch_add(1); }
+
+ void OnWsMessage(const WebSocketMessage& Msg) override
+ {
+ if (Msg.Opcode == WebSocketOpcode::kText)
+ {
+ std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
+ m_LastMessage = std::string(Text);
+ }
+ m_MessageCount.fetch_add(1);
+ }
+
+ void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override
+ {
+ m_CloseCount.fetch_add(1);
+ m_LastCloseCode = Code;
+ }
+
+ std::atomic<int> m_OpenCount{0};
+ std::atomic<int> m_MessageCount{0};
+ std::atomic<int> m_CloseCount{0};
+ std::atomic<uint16_t> m_LastCloseCode{0};
+ std::string m_LastMessage;
+ };
+
+} // anonymous namespace
+
+TEST_CASE("websocket.client")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ Sleep(100);
+
+ SUBCASE("connect, echo, close")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);
+
+ HttpWsClient Client(Url, Handler);
+ Client.Connect();
+
+ // Wait for OnWsOpen
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+ CHECK(Client.IsOpen());
+
+ // Send text, expect echo
+ Client.SendText("hello from client");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ CHECK_EQ(Handler.m_MessageCount.load(), 1);
+ CHECK_EQ(Handler.m_LastMessage, "hello from client");
+
+ // Close
+ Client.Close(1000, "done");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ // The server echoes the close frame, which triggers OnWsClose on the client side
+ // with the server's close code. Allow the connection to settle.
+ Sleep(50);
+ CHECK_FALSE(Client.IsOpen());
+ }
+
+ SUBCASE("connect to bad port")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = "ws://127.0.0.1:1/wstest/ws";
+
+ HttpWsClient Client(Url, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)});
+ Client.Connect();
+
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ CHECK_EQ(Handler.m_CloseCount.load(), 1);
+ CHECK_EQ(Handler.m_LastCloseCode.load(), 1006);
+ CHECK_EQ(Handler.m_OpenCount.load(), 0);
+ }
+
+ SUBCASE("server-initiated close")
+ {
+ TestWsClientHandler Handler;
+ std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);
+
+ HttpWsClient Client(Url, Handler);
+ Client.Connect();
+
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+
+ // Copy connections then close them outside the lock to avoid deadlocking
+ // with OnWebSocketClose which acquires an exclusive lock
+ std::vector<Ref<WebSocketConnection>> Conns;
+ TestService.m_ConnectionsLock.WithSharedLock([&] { Conns = TestService.m_Connections; });
+ for (auto& Conn : Conns)
+ {
+ Conn->Close(1001, "going away");
+ }
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ CHECK_EQ(Handler.m_CloseCount.load(), 1);
+ CHECK_EQ(Handler.m_LastCloseCode.load(), 1001);
+ CHECK_FALSE(Client.IsOpen());
+ }
+}
+
+TEST_CASE("websocket.client.unixsocket")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+ std::string SocketPath = (TmpDir.Path() / "ws.sock").string();
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath});
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ Sleep(100);
+
+ SUBCASE("connect, echo, close over unix socket")
+ {
+ TestWsClientHandler Handler;
+ HttpWsClientSettings Settings;
+ Settings.UnixSocketPath = SocketPath;
+
+ HttpWsClient Client("ws://localhost/wstest/ws", Handler, Settings);
+ Client.Connect();
+
+ // Wait for OnWsOpen
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+ CHECK(Client.IsOpen());
+
+ // Send text, expect echo
+ Client.SendText("hello over unix socket");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ CHECK_EQ(Handler.m_MessageCount.load(), 1);
+ CHECK_EQ(Handler.m_LastMessage, "hello over unix socket");
+
+ // Close
+ Client.Close(1000, "done");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ Sleep(50);
+ CHECK_FALSE(Client.IsOpen());
+ }
+}
+
+TEST_SUITE_END();
+
+void
+websocket_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_TESTS
diff --git a/src/zenhttp/transports/asiotransport.cpp b/src/zenhttp/transports/asiotransport.cpp
index 23ac1bc8b..d5413b9af 100644
--- a/src/zenhttp/transports/asiotransport.cpp
+++ b/src/zenhttp/transports/asiotransport.cpp
@@ -47,10 +47,10 @@ private:
uint16_t m_BasePort = 8558;
int m_ThreadCount = 0;
- asio::io_service m_IoService;
- asio::io_service::work m_Work{m_IoService};
- std::unique_ptr<AsioTransportAcceptor> m_Acceptor;
- std::vector<std::thread> m_ThreadPool;
+ asio::io_context m_IoService;
+ asio::executor_work_guard<asio::io_context::executor_type> m_Work{m_IoService.get_executor()};
+ std::unique_ptr<AsioTransportAcceptor> m_Acceptor;
+ std::vector<std::thread> m_ThreadPool;
};
struct AsioTransportConnection : public TransportConnection, std::enable_shared_from_this<AsioTransportConnection>
@@ -85,7 +85,7 @@ private:
struct AsioTransportAcceptor
{
- AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort);
+ AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_context& IoService, uint16_t BasePort);
~AsioTransportAcceptor();
void Start();
@@ -95,7 +95,7 @@ struct AsioTransportAcceptor
private:
TransportServer* m_ServerInterface = nullptr;
- asio::io_service& m_IoService;
+ asio::io_context& m_IoService;
asio::ip::tcp::acceptor m_Acceptor;
std::atomic<bool> m_IsStopped{false};
@@ -104,7 +104,7 @@ private:
//////////////////////////////////////////////////////////////////////////
-AsioTransportAcceptor::AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort)
+AsioTransportAcceptor::AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_context& IoService, uint16_t BasePort)
: m_ServerInterface(ServerInterface)
, m_IoService(IoService)
, m_Acceptor(m_IoService, asio::ip::tcp::v6())
diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp
index 9135d5425..489324aba 100644
--- a/src/zenhttp/transports/dlltransport.cpp
+++ b/src/zenhttp/transports/dlltransport.cpp
@@ -72,20 +72,36 @@ DllTransportLogger::DllTransportLogger(std::string_view PluginName) : m_PluginNa
void
DllTransportLogger::LogMessage(LogLevel PluginLogLevel, const char* Message)
{
- logging::level::LogLevel Level;
- // clang-format off
switch (PluginLogLevel)
{
- case LogLevel::Trace: Level = logging::level::Trace; break;
- case LogLevel::Debug: Level = logging::level::Debug; break;
- case LogLevel::Info: Level = logging::level::Info; break;
- case LogLevel::Warn: Level = logging::level::Warn; break;
- case LogLevel::Err: Level = logging::level::Err; break;
- case LogLevel::Critical: Level = logging::level::Critical; break;
- default: Level = logging::level::Off; break;
+ case LogLevel::Trace:
+ ZEN_TRACE("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Debug:
+ ZEN_DEBUG("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Info:
+ ZEN_INFO("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Warn:
+ ZEN_WARN("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Err:
+ ZEN_ERROR("[{}] {}", m_PluginName, Message);
+ return;
+
+ case LogLevel::Critical:
+ ZEN_CRITICAL("[{}] {}", m_PluginName, Message);
+ return;
+
+ default:
+ ZEN_UNUSED(Message);
+ break;
}
- // clang-format on
- ZEN_LOG(Log(), Level, "[{}] {}", m_PluginName, Message)
}
uint32_t
diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua
index 78876d21b..b4c65ea96 100644
--- a/src/zenhttp/xmake.lua
+++ b/src/zenhttp/xmake.lua
@@ -6,11 +6,22 @@ target('zenhttp')
add_headerfiles("**.h")
add_files("**.cpp")
add_files("servers/httpsys.cpp", {unity_ignored=true})
+ add_files("servers/wshttpsys.cpp", {unity_ignored=true})
add_includedirs("include", {public=true})
- add_deps("zencore", "zentelemetry", "transport-sdk", "asio", "cpr")
+ add_deps("zencore", "zentelemetry", "transport-sdk", "asio")
+ if has_config("zencpr") then
+ add_deps("cpr")
+ else
+ remove_files("clients/httpclientcpr.cpp")
+ end
add_packages("http_parser", "json11")
add_options("httpsys")
+ if is_plat("linux", "macosx") then
+ add_packages("openssl3")
+ end
+
if is_plat("linux") then
add_syslinks("dl") -- TODO: is libdl needed?
end
+
diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp
index 0b5408453..3ac8eea8d 100644
--- a/src/zenhttp/zenhttp.cpp
+++ b/src/zenhttp/zenhttp.cpp
@@ -16,8 +16,10 @@ zenhttp_forcelinktests()
{
http_forcelink();
httpclient_forcelink();
+ httpclient_test_forcelink();
forcelink_packageformat();
passwordsecurity_forcelink();
+ websocket_forcelink();
}
} // namespace zen