aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-03-10 17:27:26 +0100
committerGitHub Enterprise <[email protected]>2026-03-10 17:27:26 +0100
commitd0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7 (patch)
tree2dfe1e3e0b620043d358e0b7f8bdf8320d985491 /src/zenhttp
parentchangelog entry which was inadvertently omitted from PR merge (diff)
downloadzen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.tar.xz
zen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.zip
HttpClient using libcurl, Unix Sockets for HTTP. HTTPS support (#770)
The main goal of this change is to eliminate the cpr back-end altogether and replace it with the curl implementation. I would expect to drop cpr as soon as we feel happy with the libcurl back-end. That would leave us with a direct dependency on libcurl only, and cpr can be eliminated as a dependency. ### HttpClient Backend Overhaul - Implemented a new **libcurl-based HttpClient** backend (`httpclientcurl.cpp`, ~2000 lines) as an alternative to the cpr-based one - Made HttpClient backend **configurable at runtime** via constructor arguments and `-httpclient=...` CLI option (for zen, zenserver, and tests) - Extended HttpClient test suite to cover multipart/content-range scenarios ### Unix Domain Socket Support - Added Unix domain socket support to **httpasio** (server side) - Added Unix domain socket support to **HttpClient** - Added Unix domain socket support to **HttpWsClient** (WebSocket client) - Templatized `HttpServerConnectionT<SocketType>` and `WsAsioConnectionT<SocketType>` to handle TCP, Unix, and SSL sockets uniformly via `if constexpr` dispatch ### HTTPS Support - Added **preliminary HTTPS support to httpasio** (for Mac/Linux via OpenSSL) - Added **basic HTTPS support for http.sys** (Windows) - Implemented HTTPS test for httpasio - Split `InitializeServer` into smaller sub-functions for http.sys ### Other Notable Changes - Improved **zenhttp-test stability** with dynamic port allocation - Enhanced port retry logic in http.sys (handles ERROR_ACCESS_DENIED) - Fatal signal/exception handlers for backtrace generation in tests - Added `zen bench http` subcommand to exercise network + HTTP client/server communication stack
Diffstat (limited to 'src/zenhttp')
-rw-r--r--src/zenhttp/clients/httpclientcommon.cpp57
-rw-r--r--src/zenhttp/clients/httpclientcpr.cpp135
-rw-r--r--src/zenhttp/clients/httpclientcurl.cpp1947
-rw-r--r--src/zenhttp/clients/httpclientcurl.h135
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp213
-rw-r--r--src/zenhttp/httpclient.cpp121
-rw-r--r--src/zenhttp/httpclient_test.cpp299
-rw-r--r--src/zenhttp/httpserver.cpp23
-rw-r--r--src/zenhttp/include/zenhttp/formatters.h2
-rw-r--r--src/zenhttp/include/zenhttp/httpclient.h74
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h15
-rw-r--r--src/zenhttp/include/zenhttp/httpwsclient.h4
-rw-r--r--src/zenhttp/servers/asio_socket_traits.h54
-rw-r--r--src/zenhttp/servers/httpasio.cpp687
-rw-r--r--src/zenhttp/servers/httpasio.h6
-rw-r--r--src/zenhttp/servers/httpsys.cpp409
-rw-r--r--src/zenhttp/servers/httpsys.h4
-rw-r--r--src/zenhttp/servers/wsasio.cpp64
-rw-r--r--src/zenhttp/servers/wsasio.h43
-rw-r--r--src/zenhttp/servers/wstest.cpp73
-rw-r--r--src/zenhttp/xmake.lua5
21 files changed, 3921 insertions, 449 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp
index 6f4c67dd0..e4d11547a 100644
--- a/src/zenhttp/clients/httpclientcommon.cpp
+++ b/src/zenhttp/clients/httpclientcommon.cpp
@@ -646,6 +646,63 @@ TEST_CASE("CompositeBufferReadStream")
CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data));
}
+TEST_CASE("ParseContentRange")
+{
+ SUBCASE("normal range with total size")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 0-99/500");
+ CHECK_EQ(Offset, 0);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("non-zero offset")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 2638-5111437/44369878");
+ CHECK_EQ(Offset, 2638);
+ CHECK_EQ(Length, 5111437 - 2638 + 1);
+ }
+
+ SUBCASE("wildcard total size")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 100-199/*");
+ CHECK_EQ(Offset, 100);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("no slash (total size omitted)")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 50-149");
+ CHECK_EQ(Offset, 50);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("malformed input returns zeros")
+ {
+ auto [Offset1, Length1] = detail::ParseContentRange("not-bytes 0-99/500");
+ CHECK_EQ(Offset1, 0);
+ CHECK_EQ(Length1, 0);
+
+ auto [Offset2, Length2] = detail::ParseContentRange("bytes abc-def/500");
+ CHECK_EQ(Offset2, 0);
+ CHECK_EQ(Length2, 0);
+
+ auto [Offset3, Length3] = detail::ParseContentRange("");
+ CHECK_EQ(Offset3, 0);
+ CHECK_EQ(Length3, 0);
+
+ auto [Offset4, Length4] = detail::ParseContentRange("bytes 100/500");
+ CHECK_EQ(Offset4, 0);
+ CHECK_EQ(Length4, 0);
+ }
+
+ SUBCASE("single byte range")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 42-42/1000");
+ CHECK_EQ(Offset, 42);
+ CHECK_EQ(Length, 1);
+ }
+}
+
TEST_CASE("MultipartBoundaryParser")
{
uint64_t Range1Offset = 2638;
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp
index 14e40b02a..f3082e0a2 100644
--- a/src/zenhttp/clients/httpclientcpr.cpp
+++ b/src/zenhttp/clients/httpclientcpr.cpp
@@ -14,6 +14,11 @@
#include <zenhttp/packageformat.h>
#include <algorithm>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/ssl_options.h>
+#include <cpr/unix_socket.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
namespace zen {
HttpClientBase*
@@ -24,84 +29,42 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti
static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
-bool
-HttpClient::ErrorContext::IsConnectionError() const
+//////////////////////////////////////////////////////////////////////////
+
+static HttpClientErrorCode
+MapCprError(cpr::ErrorCode Code)
{
- switch (static_cast<cpr::ErrorCode>(ErrorCode))
+ switch (Code)
{
+ case cpr::ErrorCode::OK:
+ return HttpClientErrorCode::kOK;
case cpr::ErrorCode::CONNECTION_FAILURE:
- case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kConnectionFailure;
case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
+ return HttpClientErrorCode::kHostResolutionFailure;
case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
- return true;
+ return HttpClientErrorCode::kProxyResolutionFailure;
+ case cpr::ErrorCode::INTERNAL_ERROR:
+ return HttpClientErrorCode::kInternalError;
+ case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
+ return HttpClientErrorCode::kNetworkReceiveError;
+ case cpr::ErrorCode::NETWORK_SEND_FAILURE:
+ return HttpClientErrorCode::kNetworkSendFailure;
+ case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kOperationTimedOut;
+ case cpr::ErrorCode::SSL_CONNECT_ERROR:
+ return HttpClientErrorCode::kSSLConnectError;
+ case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR:
+ case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR:
+ return HttpClientErrorCode::kSSLCertificateError;
+ case cpr::ErrorCode::SSL_CACERT_ERROR:
+ return HttpClientErrorCode::kSSLCACertError;
+ case cpr::ErrorCode::GENERIC_SSL_ERROR:
+ return HttpClientErrorCode::kGenericSSLError;
+ case cpr::ErrorCode::REQUEST_CANCELLED:
+ return HttpClientErrorCode::kRequestCancelled;
default:
- return false;
- }
-}
-
-// If we want to support different HTTP client implementations then we'll need to make this more abstract
-
-HttpClientError::ResponseClass
-HttpClientError::GetResponseClass() const
-{
- if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK)
- {
- switch ((cpr::ErrorCode)m_Error)
- {
- case cpr::ErrorCode::CONNECTION_FAILURE:
- return ResponseClass::kHttpCantConnectError;
- case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
- case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
- return ResponseClass::kHttpNoHost;
- case cpr::ErrorCode::INTERNAL_ERROR:
- case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
- case cpr::ErrorCode::NETWORK_SEND_FAILURE:
- case cpr::ErrorCode::OPERATION_TIMEDOUT:
- return ResponseClass::kHttpTimeout;
- case cpr::ErrorCode::SSL_CONNECT_ERROR:
- case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR:
- case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR:
- case cpr::ErrorCode::SSL_CACERT_ERROR:
- case cpr::ErrorCode::GENERIC_SSL_ERROR:
- return ResponseClass::kHttpSLLError;
- default:
- return ResponseClass::kHttpOtherClientError;
- }
- }
- else if (IsHttpSuccessCode(m_ResponseCode))
- {
- return ResponseClass::kSuccess;
- }
- else
- {
- switch (m_ResponseCode)
- {
- case HttpResponseCode::Unauthorized:
- return ResponseClass::kHttpUnauthorized;
- case HttpResponseCode::NotFound:
- return ResponseClass::kHttpNotFound;
- case HttpResponseCode::Forbidden:
- return ResponseClass::kHttpForbidden;
- case HttpResponseCode::Conflict:
- return ResponseClass::kHttpConflict;
- case HttpResponseCode::InternalServerError:
- return ResponseClass::kHttpInternalServerError;
- case HttpResponseCode::ServiceUnavailable:
- return ResponseClass::kHttpServiceUnavailable;
- case HttpResponseCode::BadGateway:
- return ResponseClass::kHttpBadGateway;
- case HttpResponseCode::GatewayTimeout:
- return ResponseClass::kHttpGatewayTimeout;
- default:
- if (m_ResponseCode >= HttpResponseCode::InternalServerError)
- {
- return ResponseClass::kHttpOtherServerError;
- }
- else
- {
- return ResponseClass::kHttpOtherClientError;
- }
- }
+ return HttpClientErrorCode::kOtherError;
}
}
@@ -257,8 +220,8 @@ CprHttpClient::CommonResponse(std::string_view SessionId,
.UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
.DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
.ElapsedSeconds = HttpResponse.elapsed,
- .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code),
- .ErrorMessage = HttpResponse.error.message}};
+ .Error =
+ HttpClient::ErrorContext{.ErrorCode = MapCprError(HttpResponse.error.code), .ErrorMessage = HttpResponse.error.message}};
}
if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload))
@@ -526,6 +489,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl,
{
CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}});
}
+ if (ConnectionSettings.ForbidReuseConnection)
+ {
+ CprSession->UpdateHeader({{"Connection", "close"}});
+ }
if (AccessToken)
{
CprSession->UpdateHeader({{"Authorization", AccessToken->Value}});
@@ -544,6 +511,26 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl,
CprSession->SetParameters({});
}
+ if (!ConnectionSettings.UnixSocketPath.empty())
+ {
+ CprSession->SetUnixSocket(cpr::UnixSocket(ConnectionSettings.UnixSocketPath));
+ }
+
+ if (ConnectionSettings.InsecureSsl || !ConnectionSettings.CaBundlePath.empty())
+ {
+ cpr::SslOptions SslOpts;
+ if (ConnectionSettings.InsecureSsl)
+ {
+ SslOpts.SetOption(cpr::ssl::VerifyHost{false});
+ SslOpts.SetOption(cpr::ssl::VerifyPeer{false});
+ }
+ if (!ConnectionSettings.CaBundlePath.empty())
+ {
+ SslOpts.SetOption(cpr::ssl::CaInfo{ConnectionSettings.CaBundlePath});
+ }
+ CprSession->SetSslOptions(SslOpts);
+ }
+
ExtendableStringBuilder<128> UrlBuffer;
UrlBuffer << BaseUrl << ResourcePath;
CprSession->SetUrl(UrlBuffer.c_str());
diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp
new file mode 100644
index 000000000..3cb749018
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcurl.cpp
@@ -0,0 +1,1947 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpclientcurl.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryutil.h>
+#include <zencore/compress.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/session.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zenhttp/packageformat.h>
+#include <algorithm>
+
+namespace zen {
+
+HttpClientBase*
+CreateCurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction)
+{
+ return new CurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+}
+
+static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0};
+
+//////////////////////////////////////////////////////////////////////////
+
+static HttpClientErrorCode
+MapCurlError(CURLcode Code)
+{
+ switch (Code)
+ {
+ case CURLE_OK:
+ return HttpClientErrorCode::kOK;
+ case CURLE_COULDNT_CONNECT:
+ return HttpClientErrorCode::kConnectionFailure;
+ case CURLE_COULDNT_RESOLVE_HOST:
+ return HttpClientErrorCode::kHostResolutionFailure;
+ case CURLE_COULDNT_RESOLVE_PROXY:
+ return HttpClientErrorCode::kProxyResolutionFailure;
+ case CURLE_RECV_ERROR:
+ return HttpClientErrorCode::kNetworkReceiveError;
+ case CURLE_SEND_ERROR:
+ return HttpClientErrorCode::kNetworkSendFailure;
+ case CURLE_OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kOperationTimedOut;
+ case CURLE_SSL_CONNECT_ERROR:
+ return HttpClientErrorCode::kSSLConnectError;
+ case CURLE_SSL_CERTPROBLEM:
+ return HttpClientErrorCode::kSSLCertificateError;
+ case CURLE_PEER_FAILED_VERIFICATION:
+ return HttpClientErrorCode::kSSLCACertError;
+ case CURLE_SSL_CIPHER:
+ case CURLE_SSL_ENGINE_NOTFOUND:
+ case CURLE_SSL_ENGINE_SETFAILED:
+ return HttpClientErrorCode::kGenericSSLError;
+ case CURLE_ABORTED_BY_CALLBACK:
+ return HttpClientErrorCode::kRequestCancelled;
+ default:
+ return HttpClientErrorCode::kOtherError;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Curl callback helpers
+
+struct WriteCallbackData
+{
+ std::string* Body = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<WriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0; // Signal abort to curl
+ }
+
+ Data->Body->append(Ptr, TotalBytes);
+ return TotalBytes;
+}
+
+struct HeaderCallbackData
+{
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+};
+
+static size_t
+CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<HeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ std::string_view Line(Buffer, TotalBytes);
+
+ // Trim trailing \r\n
+ while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
+ {
+ Line.remove_suffix(1);
+ }
+
+ if (Line.empty())
+ {
+ return TotalBytes;
+ }
+
+ size_t ColonPos = Line.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ std::string_view Key = Line.substr(0, ColonPos);
+ std::string_view Value = Line.substr(ColonPos + 1);
+
+ // Trim whitespace
+ while (!Key.empty() && Key.back() == ' ')
+ {
+ Key.remove_suffix(1);
+ }
+ while (!Value.empty() && Value.front() == ' ')
+ {
+ Value.remove_prefix(1);
+ }
+
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+}
+
+struct ReadCallbackData
+{
+ const uint8_t* DataPtr = nullptr;
+ size_t DataSize = 0;
+ size_t Offset = 0;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<ReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ size_t Remaining = Data->DataSize - Data->Offset;
+ size_t ToRead = std::min(MaxRead, Remaining);
+
+ if (ToRead > 0)
+ {
+ memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead);
+ Data->Offset += ToRead;
+ }
+
+ return ToRead;
+}
+
+struct StreamReadCallbackData
+{
+ detail::CompositeBufferReadStream* Reader = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlStreamReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<StreamReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ return Data->Reader->Read(Buffer, MaxRead);
+}
+
+struct FileReadCallbackData
+{
+ detail::BufferedReadFileStream* Buffer = nullptr;
+ uint64_t TotalSize = 0;
+ uint64_t Offset = 0;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlFileReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<FileReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ size_t Remaining = Data->TotalSize - Data->Offset;
+ size_t ToRead = std::min(MaxRead, Remaining);
+
+ if (ToRead > 0)
+ {
+ Data->Buffer->Read(Buffer, ToRead);
+ Data->Offset += ToRead;
+ }
+
+ return ToRead;
+}
+
+static int
+CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, void* UserPtr)
+{
+ ZEN_UNUSED(Handle);
+ LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr);
+ auto Log = [&]() -> LoggerRef { return LogRef; };
+
+ std::string_view DataView(Data, Size);
+
+ // Trim trailing newlines
+ while (!DataView.empty() && (DataView.back() == '\r' || DataView.back() == '\n'))
+ {
+ DataView.remove_suffix(1);
+ }
+
+ switch (Type)
+ {
+ case CURLINFO_TEXT:
+ if (DataView.find("need more data"sv) == std::string_view::npos)
+ {
+ ZEN_INFO("TEXT: {}", DataView);
+ }
+ break;
+ case CURLINFO_HEADER_IN:
+ ZEN_INFO("HIN : {}", DataView);
+ break;
+ case CURLINFO_HEADER_OUT:
+ if (auto TokenPos = DataView.find("Authorization: Bearer "sv); TokenPos != std::string_view::npos)
+ {
+ std::string Copy(DataView);
+ auto BearerStart = TokenPos + 22;
+ auto BearerEnd = Copy.find_first_of("\r\n", BearerStart);
+ if (BearerEnd == std::string::npos)
+ {
+ BearerEnd = Copy.length();
+ }
+ Copy.replace(Copy.begin() + BearerStart, Copy.begin() + BearerEnd, fmt::format("[{} char token]", BearerEnd - BearerStart));
+ ZEN_INFO("HOUT: {}", Copy);
+ }
+ else
+ {
+ ZEN_INFO("HOUT: {}", DataView);
+ }
+ break;
+ default:
+ break;
+ }
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+static std::pair<std::string, std::string>
+HeaderContentType(ZenContentType ContentType)
+{
+ return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)));
+}
+
+static curl_slist*
+BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
+ std::string_view SessionId,
+ const std::optional<HttpClientAccessToken>& AccessToken,
+ const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {})
+{
+ curl_slist* Headers = nullptr;
+
+ for (const auto& [Key, Value] : *AdditionalHeader)
+ {
+ std::string HeaderLine = fmt::format("{}: {}", Key, Value);
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ if (!SessionId.empty())
+ {
+ std::string SessionHeader = fmt::format("UE-Session: {}", SessionId);
+ Headers = curl_slist_append(Headers, SessionHeader.c_str());
+ }
+
+ if (AccessToken)
+ {
+ std::string AuthHeader = fmt::format("Authorization: {}", AccessToken->Value);
+ Headers = curl_slist_append(Headers, AuthHeader.c_str());
+ }
+
+ for (const auto& [Key, Value] : ExtraHeaders)
+ {
+ std::string HeaderLine = fmt::format("{}: {}", Key, Value);
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ return Headers;
+}
+
+static std::string
+BuildUrlWithParameters(std::string_view BaseUrl, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters)
+{
+ std::string Url;
+ Url.reserve(BaseUrl.size() + ResourcePath.size() + 64);
+ Url.append(BaseUrl);
+ Url.append(ResourcePath);
+
+ if (!Parameters->empty())
+ {
+ char Separator = '?';
+ for (const auto& [Key, Value] : *Parameters)
+ {
+ char* EncodedKey = curl_easy_escape(nullptr, Key.c_str(), static_cast<int>(Key.size()));
+ char* EncodedValue = curl_easy_escape(nullptr, Value.c_str(), static_cast<int>(Value.size()));
+ Url += Separator;
+ Url += EncodedKey;
+ Url += '=';
+ Url += EncodedValue;
+ curl_free(EncodedKey);
+ curl_free(EncodedValue);
+ Separator = '&';
+ }
+ }
+
+ return Url;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CurlHttpClient::CurlHttpClient(std::string_view BaseUri,
+ const HttpClientSettings& ConnectionSettings,
+ std::function<bool()>&& CheckIfAbortFunction)
+: HttpClientBase(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction))
+{
+}
+
+CurlHttpClient::~CurlHttpClient()
+{
+ ZEN_TRACE_CPU("CurlHttpClient::~CurlHttpClient");
+ m_SessionLock.WithExclusiveLock([&] {
+ for (auto* Handle : m_Sessions)
+ {
+ curl_easy_cleanup(Handle);
+ }
+ m_Sessions.clear();
+ });
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::Session::Perform()
+{
+ CurlResult Result;
+
+ char ErrorBuffer[CURL_ERROR_SIZE] = {};
+ curl_easy_setopt(Handle, CURLOPT_ERRORBUFFER, ErrorBuffer);
+
+ Result.ErrorCode = curl_easy_perform(Handle);
+
+ if (Result.ErrorCode != CURLE_OK)
+ {
+ Result.ErrorMessage = ErrorBuffer[0] ? std::string(ErrorBuffer) : curl_easy_strerror(Result.ErrorCode);
+ }
+
+ curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &Result.StatusCode);
+
+ double Elapsed = 0;
+ curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed);
+ Result.ElapsedSeconds = Elapsed;
+
+ curl_off_t UpBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes);
+ Result.UploadedBytes = static_cast<int64_t>(UpBytes);
+
+ curl_off_t DownBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes);
+ Result.DownloadedBytes = static_cast<int64_t>(DownBytes);
+
+ return Result;
+}
+
+bool
+CurlHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const
+{
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return false;
+ }
+ const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes;
+ return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end();
+}
+
+HttpClient::Response
+CurlHttpClient::ResponseWithPayload(std::string_view SessionId,
+ CurlResult&& Result,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
+{
+ IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, Result.Body.data(), Result.Body.size());
+
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ if (Key == "Content-Type")
+ {
+ const HttpContentType ContentType = ParseContentType(Value);
+ ResponseBuffer.SetContentType(ContentType);
+ break;
+ }
+ }
+
+ if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
+ {
+ if (ShouldLogErrorCode(WorkResponseCode))
+ {
+ ZEN_WARN("HttpClient request failed (session: {}): status={}, url={}",
+ SessionId,
+ static_cast<int>(WorkResponseCode),
+ m_BaseUri);
+ }
+ }
+
+ std::sort(BoundaryPositions.begin(),
+ BoundaryPositions.end(),
+ [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) {
+ return Lhs.RangeOffset < Rhs.RangeOffset;
+ });
+
+ HttpClient::KeyValueMap HeaderMap;
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ HeaderMap->insert_or_assign(Key, Value);
+ }
+
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .ResponsePayload = std::move(ResponseBuffer),
+ .Header = std::move(HeaderMap),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Ranges = std::move(BoundaryPositions)};
+}
+
+HttpClient::Response
+CurlHttpClient::CommonResponse(std::string_view SessionId,
+ CurlResult&& Result,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
+{
+ const HttpResponseCode WorkResponseCode = HttpResponseCode(Result.StatusCode);
+ if (Result.ErrorCode != CURLE_OK)
+ {
+ const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
+ if (!Quiet)
+ {
+ if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT &&
+ Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK)
+ {
+ ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'",
+ SessionId,
+ static_cast<int>(Result.ErrorCode),
+ Result.ErrorMessage);
+ }
+ }
+
+ HttpClient::KeyValueMap HeaderMap;
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ HeaderMap->insert_or_assign(Key, Value);
+ }
+
+ return HttpClient::Response{
+ .StatusCode = WorkResponseCode,
+ .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Result.Body.data(), Result.Body.size()),
+ .Header = std::move(HeaderMap),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(Result.ErrorCode), .ErrorMessage = Result.ErrorMessage}};
+ }
+
+ if (WorkResponseCode == HttpResponseCode::NoContent || (Result.Body.empty() && !Payload))
+ {
+ HttpClient::KeyValueMap HeaderMap;
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ HeaderMap->insert_or_assign(Key, Value);
+ }
+
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .Header = std::move(HeaderMap),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds};
+ }
+ else
+ {
+ return ResponseWithPayload(SessionId, std::move(Result), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions));
+ }
+}
+
+bool
+CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile)
+{
+ ZEN_TRACE_CPU("ValidatePayload");
+
+ IoBuffer ResponseBuffer = (Result.Body.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer()
+ : IoBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size());
+
+ // Find Content-Length in headers
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ if (Key == "Content-Length")
+ {
+ std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(Value);
+ if (!ExpectedContentSize.has_value())
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", Value);
+ return false;
+ }
+ if (ExpectedContentSize.value() != ResponseBuffer.GetSize())
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), Value);
+ return false;
+ }
+ break;
+ }
+ }
+
+ if (Result.StatusCode == static_cast<long>(HttpResponseCode::PartialContent))
+ {
+ return true;
+ }
+
+ // Check X-Jupiter-IoHash
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ if (Key == "X-Jupiter-IoHash")
+ {
+ IoHash ExpectedPayloadHash;
+ if (IoHash::TryParse(Value, ExpectedPayloadHash))
+ {
+ IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer);
+ if (PayloadHash != ExpectedPayloadHash)
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}",
+ PayloadHash.ToHexString(),
+ ExpectedPayloadHash.ToHexString());
+ return false;
+ }
+ }
+ break;
+ }
+ }
+
+ // Validate content-type specific payload
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ if (Key == "Content-Type")
+ {
+ if (Value == "application/x-ue-comp")
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer,
+ RawHash,
+ RawSize,
+ /*OutOptionalTotalCompressedSize*/ nullptr))
+ {
+ return true;
+ }
+ else
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = "Compressed binary failed validation";
+ return false;
+ }
+ }
+ if (Value == "application/x-ue-cb")
+ {
+ if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default);
+ Error == CbValidateError::None)
+ {
+ return true;
+ }
+ else
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error));
+ return false;
+ }
+ }
+ break;
+ }
+ }
+
+ return true;
+}
+
+bool
+CurlHttpClient::ShouldRetry(const CurlResult& Result)
+{
+ switch (Result.ErrorCode)
+ {
+ case CURLE_OK:
+ break;
+ case CURLE_RECV_ERROR:
+ case CURLE_SEND_ERROR:
+ case CURLE_OPERATION_TIMEDOUT:
+ return true;
+ default:
+ return false;
+ }
+ switch (static_cast<HttpResponseCode>(Result.StatusCode))
+ {
+ case HttpResponseCode::RequestTimeout:
+ case HttpResponseCode::TooManyRequests:
+ case HttpResponseCode::InternalServerError:
+ case HttpResponseCode::BadGateway:
+ case HttpResponseCode::ServiceUnavailable:
+ case HttpResponseCode::GatewayTimeout:
+ return true;
+ default:
+ return false;
+ }
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::function<bool(CurlResult&)>&& Validate)
+{
+ uint8_t Attempt = 0;
+ CurlResult Result = Func();
+ while (Attempt < m_ConnectionSettings.RetryCount)
+ {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return Result;
+ }
+ if (!ShouldRetry(Result))
+ {
+ if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode))
+ {
+ break;
+ }
+ if (Validate(Result))
+ {
+ break;
+ }
+ }
+ Sleep(100 * (Attempt + 1));
+ Attempt++;
+ if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode)))
+ {
+ ZEN_INFO("{} Attempt {}/{}",
+ CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
+ Attempt,
+ m_ConnectionSettings.RetryCount + 1);
+ }
+ Result = Func();
+ }
+ return Result;
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::DoWithRetry(std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::unique_ptr<detail::TempPayloadFile>& PayloadFile)
+{
+ uint8_t Attempt = 0;
+ CurlResult Result = Func();
+ while (Attempt < m_ConnectionSettings.RetryCount)
+ {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return Result;
+ }
+ if (!ShouldRetry(Result))
+ {
+ if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode))
+ {
+ break;
+ }
+ if (ValidatePayload(Result, PayloadFile))
+ {
+ break;
+ }
+ }
+ Sleep(100 * (Attempt + 1));
+ Attempt++;
+ if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode)))
+ {
+ ZEN_INFO("{} Attempt {}/{}",
+ CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
+ Attempt,
+ m_ConnectionSettings.RetryCount + 1);
+ }
+ Result = Func();
+ }
+ return Result;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CurlHttpClient::Session
+CurlHttpClient::AllocSession(std::string_view BaseUrl,
+ std::string_view ResourcePath,
+ const HttpClientSettings& ConnectionSettings,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters,
+ std::string_view SessionId,
+ std::optional<HttpClientAccessToken> AccessToken)
+{
+ ZEN_UNUSED(AccessToken, SessionId, AdditionalHeader);
+ ZEN_TRACE_CPU("CurlHttpClient::AllocSession");
+ CURL* Handle = nullptr;
+ m_SessionLock.WithExclusiveLock([&] {
+ if (!m_Sessions.empty())
+ {
+ Handle = m_Sessions.back();
+ m_Sessions.pop_back();
+ }
+ });
+
+ if (Handle == nullptr)
+ {
+ Handle = curl_easy_init();
+ }
+ else
+ {
+ curl_easy_reset(Handle);
+ }
+
+ // Unix domain socket
+ if (!ConnectionSettings.UnixSocketPath.empty())
+ {
+ curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, ConnectionSettings.UnixSocketPath.c_str());
+ }
+
+ // Build URL with parameters
+ std::string Url = BuildUrlWithParameters(BaseUrl, ResourcePath, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str());
+
+ // Timeouts
+ if (ConnectionSettings.ConnectTimeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(ConnectionSettings.ConnectTimeout.count()));
+ }
+ if (ConnectionSettings.Timeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(ConnectionSettings.Timeout.count()));
+ }
+
+ // HTTP/2
+ if (ConnectionSettings.AssumeHttp2)
+ {
+ curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE);
+ }
+
+ // Verbose/debug
+ if (ConnectionSettings.Verbose)
+ {
+ curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L);
+ curl_easy_setopt(Handle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback);
+ curl_easy_setopt(Handle, CURLOPT_DEBUGDATA, &m_Log);
+ }
+
+ // SSL options
+ if (ConnectionSettings.InsecureSsl)
+ {
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L);
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L);
+ }
+ if (!ConnectionSettings.CaBundlePath.empty())
+ {
+ curl_easy_setopt(Handle, CURLOPT_CAINFO, ConnectionSettings.CaBundlePath.c_str());
+ }
+
+ // Disable signal handling for thread safety
+ curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L);
+
+ if (ConnectionSettings.ForbidReuseConnection)
+ {
+ curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L);
+ }
+
+ // Note: Headers are NOT set here. Each method builds its own header list
+ // (potentially adding method-specific headers like Content-Type) and is
+ // responsible for freeing it with curl_slist_free_all.
+
+ return Session(this, Handle);
+}
+
+void
+CurlHttpClient::ReleaseSession(CURL* Handle)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::ReleaseSession");
+
+ // Free any header list that was set
+ // curl_easy_reset will be called on next AllocSession, which cleans up the handle state.
+ // We just push the handle back to the pool.
+ m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(Handle); });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CurlHttpClient::Response
+CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::TransactPackage");
+
+ // First, list of offered chunks for filtering on the server end
+
+ std::vector<IoHash> AttachmentsToSend;
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+
+ const uint32_t RequestId = ++CurlHttpClientRequestIdCounter;
+ auto RequestIdString = fmt::to_string(RequestId);
+
+ if (Attachments.empty() == false)
+ {
+ CbObjectWriter Writer;
+ Writer.BeginArray("offer");
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ Writer.AddHash(Attachment.GetHash());
+ }
+
+ Writer.EndArray();
+
+ BinaryWriter MemWriter;
+ Writer.Save(MemWriter);
+
+ std::vector<std::pair<std::string, std::string>> OfferExtraHeaders;
+ OfferExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackageOffer));
+ OfferExtraHeaders.emplace_back("UE-Request", RequestIdString);
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders);
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList);
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(MemWriter.Data()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(MemWriter.Size()));
+
+ std::string FilterBody;
+ WriteCallbackData WriteData{.Body = &FilterBody};
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+
+ CurlResult Result = Sess.Perform();
+
+ curl_slist_free_all(HeaderList);
+
+ if (Result.ErrorCode == CURLE_OK && Result.StatusCode == 200)
+ {
+ IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterBody.data(), FilterBody.size());
+ CbValidateError ValidationError = CbValidateError::None;
+ if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError);
+ ValidationError == CbValidateError::None)
+ {
+ for (CbFieldView& Entry : ResponseObject["need"])
+ {
+ ZEN_ASSERT(Entry.IsHash());
+ AttachmentsToSend.push_back(Entry.AsHash());
+ }
+ }
+ }
+ }
+
+ // Prepare package for send
+
+ CbPackage SendPackage;
+ SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash());
+
+ for (const IoHash& AttachmentCid : AttachmentsToSend)
+ {
+ const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid);
+
+ if (Attachment)
+ {
+ SendPackage.AddAttachment(*Attachment);
+ }
+ }
+
+ // Transmit package payload
+
+ CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage);
+ SharedBuffer FlatMessage = Message.Flatten();
+
+ std::vector<std::pair<std::string, std::string>> PkgExtraHeaders;
+ PkgExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackage));
+ PkgExtraHeaders.emplace_back("UE-Request", RequestIdString);
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders);
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList);
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(FlatMessage.GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(FlatMessage.GetSize()));
+
+ std::string PkgBody;
+ WriteCallbackData WriteData{.Body = &PkgBody};
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+
+ CurlResult Result = Sess.Perform();
+
+ curl_slist_free_all(HeaderList);
+
+ if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode))
+ {
+ return {.StatusCode = HttpResponseCode(Result.StatusCode)};
+ }
+
+ IoBuffer ResponseBuffer(IoBuffer::Clone, PkgBody.data(), PkgBody.size());
+
+ return {.StatusCode = HttpResponseCode(Result.StatusCode), .ResponsePayload = ResponseBuffer};
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Standard HTTP verbs
+//
+
+CurlHttpClient::Response
+CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Put");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers =
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Put");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}};
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeaderWithContentLength, Parameters, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers = BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Get");
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+ curl_easy_setopt(H, CURLOPT_HTTPGET, 1L);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ },
+ [this](CurlResult& Result) {
+ std::unique_ptr<detail::TempPayloadFile> NoTempFile;
+ return ValidatePayload(Result, NoTempFile);
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Head");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+ curl_easy_setopt(H, CURLOPT_NOBODY, 1L);
+
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Delete");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+ curl_easy_setopt(H, CURLOPT_CUSTOMREQUEST, "DELETE");
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostNoPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE, 0L);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader);
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostWithPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ // Rebuild headers with content type
+ curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+
+ FileReadCallbackData ReadData{.Buffer = &Buffer,
+ .TotalSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostObjectPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers =
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetBuffer().GetSize()));
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader);
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Post");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers =
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+
+ StreamReadCallbackData ReadData{.Reader = &Reader,
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers =
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+
+ FileReadCallbackData ReadData{.Buffer = &Buffer,
+ .TotalSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }
+
+ ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Upload(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers =
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+
+ StreamReadCallbackData ReadData{.Reader = &Reader,
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Download");
+
+ std::string PayloadString;
+ std::unique_ptr<detail::TempPayloadFile> PayloadFile;
+
+ HttpContentType ContentType = HttpContentType::kUnknownContentType;
+ detail::MultipartBoundaryParser BoundaryParser;
+ bool IsMultiRangeResponse = false;
+
+ CurlResult Result = DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* DlHeaders = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, DlHeaders);
+ curl_easy_setopt(H, CURLOPT_HTTPGET, 1L);
+
+ // Reset state from any previous attempt
+ PayloadString.clear();
+ PayloadFile.reset();
+ BoundaryParser.Boundaries.clear();
+ ContentType = HttpContentType::kUnknownContentType;
+ IsMultiRangeResponse = false;
+
+ // Track requested content length from Range header (sum all ranges)
+ uint64_t RequestedContentLength = (uint64_t)-1;
+ if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
+ {
+ if (RangeIt->second.starts_with("bytes"))
+ {
+ std::string_view RangeValue(RangeIt->second);
+ size_t RangeStartPos = RangeValue.find('=', 5);
+ if (RangeStartPos != std::string::npos)
+ {
+ RangeStartPos++;
+ while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ')
+ {
+ RangeStartPos++;
+ }
+ RequestedContentLength = 0;
+
+ while (RangeStartPos < RangeValue.length())
+ {
+ size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos);
+ if (RangeEnd == std::string::npos)
+ {
+ RangeEnd = RangeValue.length();
+ }
+
+ std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos);
+ size_t RangeSplitPos = RangeString.find('-');
+ if (RangeSplitPos != std::string::npos)
+ {
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1;
+ }
+ }
+ RangeStartPos = RangeEnd;
+ while (RangeStartPos != RangeValue.length() &&
+ (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' '))
+ {
+ RangeStartPos++;
+ }
+ }
+ }
+ }
+ }
+
+ // Header callback that detects Content-Length and switches to file-backed storage when needed
+ struct DownloadHeaderCallbackData
+ {
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::string* PayloadString = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ uint64_t MaxInMemorySize = 0;
+ LoggerRef Log;
+ detail::MultipartBoundaryParser* BoundaryParser = nullptr;
+ bool* IsMultiRange = nullptr;
+ HttpContentType* ContentTypeOut = nullptr;
+ };
+
+ DownloadHeaderCallbackData DlHdrData;
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ DlHdrData.Headers = &ResponseHeaders;
+ DlHdrData.PayloadFile = &PayloadFile;
+ DlHdrData.PayloadString = &PayloadString;
+ DlHdrData.TempFolderPath = &TempFolderPath;
+ DlHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize;
+ DlHdrData.Log = m_Log;
+ DlHdrData.BoundaryParser = &BoundaryParser;
+ DlHdrData.IsMultiRange = &IsMultiRangeResponse;
+ DlHdrData.ContentTypeOut = &ContentType;
+
+ auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<DownloadHeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ std::string_view Line(Buffer, TotalBytes);
+
+ while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
+ {
+ Line.remove_suffix(1);
+ }
+
+ if (Line.empty())
+ {
+ return TotalBytes;
+ }
+
+ size_t ColonPos = Line.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ std::string_view Key = Line.substr(0, ColonPos);
+ std::string_view Value = Line.substr(ColonPos + 1);
+
+ while (!Key.empty() && Key.back() == ' ')
+ {
+ Key.remove_suffix(1);
+ }
+ while (!Value.empty() && Value.front() == ' ')
+ {
+ Value.remove_prefix(1);
+ }
+
+ if (Key == "Content-Length"sv)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Value);
+ if (ContentLength.has_value())
+ {
+ if (ContentLength.value() > Data->MaxInMemorySize)
+ {
+ *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ Data->PayloadFile->reset();
+ }
+ }
+ else
+ {
+ Data->PayloadString->reserve(ContentLength.value());
+ }
+ }
+ }
+ else if (Key == "Content-Type"sv)
+ {
+ *Data->IsMultiRange = Data->BoundaryParser->Init(Value);
+ if (!*Data->IsMultiRange)
+ {
+ *Data->ContentTypeOut = ParseContentType(Value);
+ }
+ }
+ else if (Key == "Content-Range"sv)
+ {
+ if (!*Data->IsMultiRange)
+ {
+ std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Value);
+ if (Range.second != 0)
+ {
+ Data->BoundaryParser->Boundaries.push_back(
+ HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0,
+ .RangeOffset = Range.first,
+ .RangeLength = Range.second,
+ .ContentType = *Data->ContentTypeOut});
+ }
+ }
+ }
+
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb));
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &DlHdrData);
+
+ // Write callback that directs data to file or string
+ struct DownloadWriteCallbackData
+ {
+ std::string* PayloadString = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ LoggerRef Log;
+ detail::MultipartBoundaryParser* BoundaryParser = nullptr;
+ bool* IsMultiRange = nullptr;
+ };
+
+ DownloadWriteCallbackData DlWriteData;
+ DlWriteData.PayloadString = &PayloadString;
+ DlWriteData.PayloadFile = &PayloadFile;
+ DlWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr;
+ DlWriteData.TempFolderPath = &TempFolderPath;
+ DlWriteData.Log = m_Log;
+ DlWriteData.BoundaryParser = &BoundaryParser;
+ DlWriteData.IsMultiRange = &IsMultiRangeResponse;
+
+ auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<DownloadWriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0;
+ }
+
+ if (*Data->IsMultiRange)
+ {
+ Data->BoundaryParser->ParseInput(std::string_view(Ptr, TotalBytes));
+ }
+
+ if (*Data->PayloadFile)
+ {
+ std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes));
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ return 0;
+ }
+ }
+ else
+ {
+ Data->PayloadString->append(Ptr, TotalBytes);
+ }
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb));
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &DlWriteData);
+
+ CurlResult Res = Sess.Perform();
+ Res.Headers = std::move(ResponseHeaders);
+
+ // Handle resume logic
+ if (m_ConnectionSettings.AllowResume)
+ {
+ auto SupportsRanges = [](const CurlResult& R) -> bool {
+ for (const auto& [K, V] : R.Headers)
+ {
+ if (K == "Content-Range")
+ {
+ return true;
+ }
+ if (K == "Accept-Ranges" && V == "bytes")
+ {
+ return true;
+ }
+ }
+ return false;
+ };
+
+ auto ShouldResumeCheck = [&SupportsRanges, &IsMultiRangeResponse](const CurlResult& R) -> bool {
+ if (IsMultiRangeResponse)
+ {
+ return false;
+ }
+ if (ShouldRetry(R))
+ {
+ return SupportsRanges(R);
+ }
+ return false;
+ };
+
+ if (ShouldResumeCheck(Res))
+ {
+ // Find Content-Length
+ std::string ContentLengthValue;
+ for (const auto& [K, V] : Res.Headers)
+ {
+ if (K == "Content-Length")
+ {
+ ContentLengthValue = V;
+ break;
+ }
+ }
+
+ if (!ContentLengthValue.empty())
+ {
+ uint64_t ContentLength = RequestedContentLength;
+ if (ContentLength == uint64_t(-1))
+ {
+ if (auto ParsedContentLength = ParseInt<int64_t>(ContentLengthValue); ParsedContentLength.has_value())
+ {
+ ContentLength = ParsedContentLength.value();
+ }
+ }
+
+ KeyValueMap HeadersWithRange(AdditionalHeader);
+ do
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+
+ std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
+ if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
+ {
+ if (RangeIt->second == Range)
+ {
+ break; // No progress, abort
+ }
+ }
+ HeadersWithRange.Entries.insert_or_assign("Range", Range);
+
+ Session ResumeSess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
+ CURL* ResumeH = ResumeSess.Get();
+
+ curl_slist* ResumeHdrList = BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken());
+ curl_easy_setopt(ResumeH, CURLOPT_HTTPHEADER, ResumeHdrList);
+ curl_easy_setopt(ResumeH, CURLOPT_HTTPGET, 1L);
+
+ std::vector<std::pair<std::string, std::string>> ResumeHeaders;
+
+ struct ResumeHeaderCbData
+ {
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::string* PayloadString = nullptr;
+ };
+
+ ResumeHeaderCbData ResumeHdrData;
+ ResumeHdrData.Headers = &ResumeHeaders;
+ ResumeHdrData.PayloadFile = &PayloadFile;
+ ResumeHdrData.PayloadString = &PayloadString;
+
+ auto ResumeHeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<ResumeHeaderCbData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ std::string_view Line(Buffer, TotalBytes);
+ while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
+ {
+ Line.remove_suffix(1);
+ }
+
+ if (Line.empty())
+ {
+ return TotalBytes;
+ }
+
+ size_t ColonPos = Line.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ std::string_view Key = Line.substr(0, ColonPos);
+ std::string_view Value = Line.substr(ColonPos + 1);
+ while (!Key.empty() && Key.back() == ' ')
+ {
+ Key.remove_suffix(1);
+ }
+ while (!Value.empty() && Value.front() == ' ')
+ {
+ Value.remove_prefix(1);
+ }
+
+ if (Key == "Content-Range"sv)
+ {
+ if (Value.starts_with("bytes "sv))
+ {
+ size_t RangeStartEnd = Value.find('-', 6);
+ if (RangeStartEnd != std::string_view::npos)
+ {
+ const std::optional<uint64_t> Start =
+ ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6));
+ if (Start)
+ {
+ uint64_t DownloadedSize = *Data->PayloadFile ? (*Data->PayloadFile)->GetSize()
+ : Data->PayloadString->length();
+ if (Start.value() == DownloadedSize)
+ {
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ return TotalBytes;
+ }
+ else if (Start.value() > DownloadedSize)
+ {
+ return 0;
+ }
+ if (*Data->PayloadFile)
+ {
+ (*Data->PayloadFile)->ResetWritePos(Start.value());
+ }
+ else
+ {
+ *Data->PayloadString = Data->PayloadString->substr(0, Start.value());
+ }
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ return TotalBytes;
+ }
+ }
+ }
+ return 0;
+ }
+
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(ResumeH,
+ CURLOPT_HEADERFUNCTION,
+ static_cast<size_t (*)(char*, size_t, size_t, void*)>(ResumeHeaderCb));
+ curl_easy_setopt(ResumeH, CURLOPT_HEADERDATA, &ResumeHdrData);
+ curl_easy_setopt(ResumeH,
+ CURLOPT_WRITEFUNCTION,
+ static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb));
+ curl_easy_setopt(ResumeH, CURLOPT_WRITEDATA, &DlWriteData);
+
+ Res = ResumeSess.Perform();
+ Res.Headers = std::move(ResumeHeaders);
+
+ curl_slist_free_all(ResumeHdrList);
+ } while (ShouldResumeCheck(Res));
+ }
+ }
+ }
+
+ if (!PayloadString.empty())
+ {
+ Res.Body = std::move(PayloadString);
+ }
+
+ curl_slist_free_all(DlHeaders);
+
+ return Res;
+ },
+ PayloadFile);
+
+ return CommonResponse(m_SessionId,
+ std::move(Result),
+ PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{},
+ std::move(BoundaryParser.Boundaries));
+}
+
+} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h
new file mode 100644
index 000000000..2a49ff308
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcurl.h
@@ -0,0 +1,135 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "httpclientcommon.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/httpclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <curl/curl.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+class CurlHttpClient : public HttpClientBase
+{
+public:
+ CurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction);
+ ~CurlHttpClient();
+
+ // HttpClientBase
+
+ [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Get(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Upload(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+ [[nodiscard]] virtual Response Download(std::string_view Url,
+ const std::filesystem::path& TempFolderPath,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+ [[nodiscard]] virtual Response TransactPackage(std::string_view Url,
+ CbPackage Package,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+private:
+ struct CurlResult
+ {
+ long StatusCode = 0;
+ std::string Body;
+ std::vector<std::pair<std::string, std::string>> Headers;
+ double ElapsedSeconds = 0;
+ int64_t UploadedBytes = 0;
+ int64_t DownloadedBytes = 0;
+ CURLcode ErrorCode = CURLE_OK;
+ std::string ErrorMessage;
+ };
+
+ struct Session
+ {
+ Session(CurlHttpClient* InOuter, CURL* InHandle) : Outer(InOuter), Handle(InHandle) {}
+ ~Session() { Outer->ReleaseSession(Handle); }
+
+ CURL* Get() const { return Handle; }
+
+ CurlResult Perform();
+
+ LoggerRef Log() { return Outer->Log(); }
+
+ private:
+ CurlHttpClient* Outer;
+ CURL* Handle;
+
+ Session(Session&&) = delete;
+ Session& operator=(Session&&) = delete;
+ };
+
+ Session AllocSession(std::string_view BaseUrl,
+ std::string_view Url,
+ const HttpClientSettings& ConnectionSettings,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters,
+ std::string_view SessionId,
+ std::optional<HttpClientAccessToken> AccessToken);
+
+ RwLock m_SessionLock;
+ std::vector<CURL*> m_Sessions;
+
+ void ReleaseSession(CURL* Handle);
+
+ struct RetryResult
+ {
+ CurlResult Result;
+ };
+
+ CurlResult DoWithRetry(std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
+ CurlResult DoWithRetry(
+ std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::function<bool(CurlResult&)>&& Validate = [](CurlResult&) { return true; });
+
+ bool ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
+
+ static bool ShouldRetry(const CurlResult& Result);
+
+ bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const;
+
+ HttpClient::Response CommonResponse(std::string_view SessionId,
+ CurlResult&& Result,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {});
+
+ HttpClient::Response ResponseWithPayload(std::string_view SessionId,
+ CurlResult&& Result,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions);
+};
+
+} // namespace zen
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
index 9497dadb8..792848a6b 100644
--- a/src/zenhttp/clients/httpwsclient.cpp
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -10,6 +10,9 @@
ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
ZEN_THIRD_PARTY_INCLUDES_END
#include <deque>
@@ -47,11 +50,7 @@ struct HttpWsClient::Impl
m_WorkGuard.reset();
// Close the socket to cancel pending async ops
- if (m_Socket)
- {
- asio::error_code Ec;
- m_Socket->close(Ec);
- }
+ CloseSocket();
if (m_IoThread.joinable())
{
@@ -59,6 +58,35 @@ struct HttpWsClient::Impl
}
}
+ void CloseSocket()
+ {
+ asio::error_code Ec;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixSocket)
+ {
+ m_UnixSocket->close(Ec);
+ return;
+ }
+#endif
+ if (m_TcpSocket)
+ {
+ m_TcpSocket->close(Ec);
+ }
+ }
+
+ template<typename Fn>
+ void WithSocket(Fn&& Func)
+ {
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixSocket)
+ {
+ Func(*m_UnixSocket);
+ return;
+ }
+#endif
+ Func(*m_TcpSocket);
+ }
+
void ParseUrl(std::string_view Url)
{
// Expected format: ws://host:port/path
@@ -101,9 +129,47 @@ struct HttpWsClient::Impl
m_IoThread = std::thread([this] { m_IoContext.run(); });
}
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (!m_Settings.UnixSocketPath.empty())
+ {
+ asio::post(m_IoContext, [this] { DoConnectUnix(); });
+ return;
+ }
+#endif
+
asio::post(m_IoContext, [this] { DoResolve(); });
}
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ void DoConnectUnix()
+ {
+ m_UnixSocket = std::make_unique<asio::local::stream_protocol::socket>(m_IoContext);
+
+ // Start connect timeout timer
+ m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
+ m_Timer->async_wait([this](const asio::error_code& Ec) {
+ if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect timeout for {}", m_Settings.UnixSocketPath);
+ CloseSocket();
+ }
+ });
+
+ asio::local::stream_protocol::endpoint Endpoint(m_Settings.UnixSocketPath);
+ m_UnixSocket->async_connect(Endpoint, [this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect failed for {}: {}", m_Settings.UnixSocketPath, Ec.message());
+ m_Handler.OnWsClose(1006, "connect failed");
+ return;
+ }
+
+ DoHandshake();
+ });
+ }
+#endif
+
void DoResolve()
{
m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext);
@@ -122,7 +188,7 @@ struct HttpWsClient::Impl
void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints)
{
- m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoContext);
+ m_TcpSocket = std::make_unique<asio::ip::tcp::socket>(m_IoContext);
// Start connect timeout timer
m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
@@ -130,15 +196,11 @@ struct HttpWsClient::Impl
if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
{
ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port);
- if (m_Socket)
- {
- asio::error_code CloseEc;
- m_Socket->close(CloseEc);
- }
+ CloseSocket();
}
});
- asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) {
+ asio::async_connect(*m_TcpSocket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) {
if (Ec)
{
m_Timer->cancel();
@@ -194,64 +256,68 @@ struct HttpWsClient::Impl
m_HandshakeBuffer = std::make_shared<std::string>(ReqStr);
- asio::async_write(*m_Socket,
- asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()),
- [this](const asio::error_code& Ec, std::size_t) {
- if (Ec)
- {
- m_Timer->cancel();
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message());
- m_Handler.OnWsClose(1006, "handshake write failed");
- return;
- }
-
- DoReadHandshakeResponse();
- });
+ WithSocket([this](auto& Socket) {
+ asio::async_write(Socket,
+ asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()),
+ [this](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake write failed");
+ return;
+ }
+
+ DoReadHandshakeResponse();
+ });
+ });
}
void DoReadHandshakeResponse()
{
- asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) {
- m_Timer->cancel();
+ WithSocket([this](auto& Socket) {
+ asio::async_read_until(Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) {
+ m_Timer->cancel();
- if (Ec)
- {
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message());
- m_Handler.OnWsClose(1006, "handshake read failed");
- return;
- }
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake read failed");
+ return;
+ }
- // Parse the response
- const auto& Data = m_ReadBuffer.data();
- std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
+ // Parse the response
+ const auto& Data = m_ReadBuffer.data();
+ std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
- // Consume the headers from the read buffer (any extra data stays for frame parsing)
- auto HeaderEnd = Response.find("\r\n\r\n");
- if (HeaderEnd != std::string::npos)
- {
- m_ReadBuffer.consume(HeaderEnd + 4);
- }
+ // Consume the headers from the read buffer (any extra data stays for frame parsing)
+ auto HeaderEnd = Response.find("\r\n\r\n");
+ if (HeaderEnd != std::string::npos)
+ {
+ m_ReadBuffer.consume(HeaderEnd + 4);
+ }
- // Validate 101 response
- if (Response.find("101") == std::string::npos)
- {
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
- m_Handler.OnWsClose(1006, "handshake rejected");
- return;
- }
+ // Validate 101 response
+ if (Response.find("101") == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
+ m_Handler.OnWsClose(1006, "handshake rejected");
+ return;
+ }
- // Validate Sec-WebSocket-Accept
- std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
- if (Response.find(ExpectedAccept) == std::string::npos)
- {
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
- m_Handler.OnWsClose(1006, "invalid accept key");
- return;
- }
+ // Validate Sec-WebSocket-Accept
+ std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
+ if (Response.find(ExpectedAccept) == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
+ m_Handler.OnWsClose(1006, "invalid accept key");
+ return;
+ }
- m_IsOpen.store(true);
- m_Handler.OnWsOpen();
- EnqueueRead();
+ m_IsOpen.store(true);
+ m_Handler.OnWsOpen();
+ EnqueueRead();
+ });
});
}
@@ -267,8 +333,10 @@ struct HttpWsClient::Impl
return;
}
- asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) {
- OnDataReceived(Ec);
+ WithSocket([this](auto& Socket) {
+ asio::async_read(Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) {
+ OnDataReceived(Ec);
+ });
});
}
@@ -414,9 +482,11 @@ struct HttpWsClient::Impl
auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
- asio::async_write(*m_Socket,
- asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
- [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); });
+ WithSocket([this, OwnedFrame](auto& Socket) {
+ asio::async_write(Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); });
+ });
}
void OnWriteComplete(const asio::error_code& Ec)
@@ -501,11 +571,14 @@ struct HttpWsClient::Impl
// Connection state
std::unique_ptr<asio::ip::tcp::resolver> m_Resolver;
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- std::unique_ptr<asio::steady_timer> m_Timer;
- asio::streambuf m_ReadBuffer;
- std::string m_WebSocketKey;
- std::shared_ptr<std::string> m_HandshakeBuffer;
+ std::unique_ptr<asio::ip::tcp::socket> m_TcpSocket;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ std::unique_ptr<asio::local::stream_protocol::socket> m_UnixSocket;
+#endif
+ std::unique_ptr<asio::steady_timer> m_Timer;
+ asio::streambuf m_ReadBuffer;
+ std::string m_WebSocketKey;
+ std::shared_ptr<std::string> m_HandshakeBuffer;
// Write queue
RwLock m_WriteLock;
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index 281d512cf..9baf4346e 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -40,6 +40,35 @@ extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri,
const HttpClientSettings& ConnectionSettings,
std::function<bool()>&& CheckIfAbortFunction);
+extern HttpClientBase* CreateCurlHttpClient(std::string_view BaseUri,
+ const HttpClientSettings& ConnectionSettings,
+ std::function<bool()>&& CheckIfAbortFunction);
+
+static HttpClientBackend g_DefaultHttpClientBackend = HttpClientBackend::kCpr;
+
+void
+SetDefaultHttpClientBackend(HttpClientBackend Backend)
+{
+ g_DefaultHttpClientBackend = Backend;
+}
+
+void
+SetDefaultHttpClientBackend(std::string_view Backend)
+{
+ if (Backend == "cpr")
+ {
+ g_DefaultHttpClientBackend = HttpClientBackend::kCpr;
+ }
+ else if (Backend == "curl")
+ {
+ g_DefaultHttpClientBackend = HttpClientBackend::kCurl;
+ }
+ else
+ {
+ g_DefaultHttpClientBackend = HttpClientBackend::kDefault;
+ }
+}
+
using namespace std::literals;
//////////////////////////////////////////////////////////////////////////
@@ -104,6 +133,71 @@ HttpClientBase::GetAccessToken()
//////////////////////////////////////////////////////////////////////////
+HttpClientError::ResponseClass
+HttpClientError::GetResponseClass() const
+{
+ if (m_Error != HttpClientErrorCode::kOK)
+ {
+ switch (m_Error)
+ {
+ case HttpClientErrorCode::kConnectionFailure:
+ return ResponseClass::kHttpCantConnectError;
+ case HttpClientErrorCode::kHostResolutionFailure:
+ case HttpClientErrorCode::kProxyResolutionFailure:
+ return ResponseClass::kHttpNoHost;
+ case HttpClientErrorCode::kInternalError:
+ case HttpClientErrorCode::kNetworkReceiveError:
+ case HttpClientErrorCode::kNetworkSendFailure:
+ case HttpClientErrorCode::kOperationTimedOut:
+ return ResponseClass::kHttpTimeout;
+ case HttpClientErrorCode::kSSLConnectError:
+ case HttpClientErrorCode::kSSLCertificateError:
+ case HttpClientErrorCode::kSSLCACertError:
+ case HttpClientErrorCode::kGenericSSLError:
+ return ResponseClass::kHttpSLLError;
+ default:
+ return ResponseClass::kHttpOtherClientError;
+ }
+ }
+ else if (IsHttpSuccessCode(m_ResponseCode))
+ {
+ return ResponseClass::kSuccess;
+ }
+ else
+ {
+ switch (m_ResponseCode)
+ {
+ case HttpResponseCode::Unauthorized:
+ return ResponseClass::kHttpUnauthorized;
+ case HttpResponseCode::NotFound:
+ return ResponseClass::kHttpNotFound;
+ case HttpResponseCode::Forbidden:
+ return ResponseClass::kHttpForbidden;
+ case HttpResponseCode::Conflict:
+ return ResponseClass::kHttpConflict;
+ case HttpResponseCode::InternalServerError:
+ return ResponseClass::kHttpInternalServerError;
+ case HttpResponseCode::ServiceUnavailable:
+ return ResponseClass::kHttpServiceUnavailable;
+ case HttpResponseCode::BadGateway:
+ return ResponseClass::kHttpBadGateway;
+ case HttpResponseCode::GatewayTimeout:
+ return ResponseClass::kHttpGatewayTimeout;
+ default:
+ if (m_ResponseCode >= HttpResponseCode::InternalServerError)
+ {
+ return ResponseClass::kHttpOtherServerError;
+ }
+ else
+ {
+ return ResponseClass::kHttpOtherClientError;
+ }
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
std::vector<std::pair<uint64_t, uint64_t>>
HttpClient::Response::GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const
{
@@ -222,7 +316,11 @@ HttpClient::Response::ErrorMessage(std::string_view Prefix) const
{
if (Error.has_value())
{
- return fmt::format("{}{}HTTP error ({}) '{}'", Prefix, Prefix.empty() ? ""sv : ": "sv, Error->ErrorCode, Error->ErrorMessage);
+ return fmt::format("{}{}HTTP error ({}) '{}'",
+ Prefix,
+ Prefix.empty() ? ""sv : ": "sv,
+ static_cast<int>(Error->ErrorCode),
+ Error->ErrorMessage);
}
else if (StatusCode != HttpResponseCode::ImATeapot && (int)StatusCode)
{
@@ -245,19 +343,34 @@ HttpClient::Response::ThrowError(std::string_view ErrorPrefix)
{
if (!IsSuccess())
{
- throw HttpClientError(ErrorMessage(ErrorPrefix), Error.has_value() ? Error.value().ErrorCode : 0, StatusCode);
+ throw HttpClientError(ErrorMessage(ErrorPrefix),
+ Error.has_value() ? Error.value().ErrorCode : HttpClientErrorCode::kOK,
+ StatusCode);
}
}
//////////////////////////////////////////////////////////////////////////
HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction)
-: m_BaseUri(BaseUri)
+: m_Log(zen::logging::Get(ConnectionSettings.LogCategory))
+, m_BaseUri(BaseUri)
, m_ConnectionSettings(ConnectionSettings)
{
m_SessionId = GetSessionIdString();
- m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+ HttpClientBackend EffectiveBackend =
+ ConnectionSettings.Backend != HttpClientBackend::kDefault ? ConnectionSettings.Backend : g_DefaultHttpClientBackend;
+
+ switch (EffectiveBackend)
+ {
+ case HttpClientBackend::kCurl:
+ m_Inner = CreateCurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+ break;
+ case HttpClientBackend::kCpr:
+ default:
+ m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+ break;
+ }
}
HttpClient::~HttpClient()
diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp
index 52bf149a7..2d949c546 100644
--- a/src/zenhttp/httpclient_test.cpp
+++ b/src/zenhttp/httpclient_test.cpp
@@ -8,6 +8,7 @@
# include <zencore/compactbinarybuilder.h>
# include <zencore/compactbinaryutil.h>
# include <zencore/compositebuffer.h>
+# include <zencore/filesystem.h>
# include <zencore/iobuffer.h>
# include <zencore/logging.h>
# include <zencore/scopeguard.h>
@@ -232,7 +233,7 @@ struct TestServerFixture
TestServerFixture()
{
Server = CreateHttpAsioServer(AsioConfig{});
- Port = Server->Initialize(7600, TmpDir.Path());
+ Port = Server->Initialize(0, TmpDir.Path());
ZEN_ASSERT(Port != -1);
Server->RegisterService(TestService);
ServerThread = std::thread([this]() { Server->Run(false); });
@@ -1044,13 +1045,22 @@ struct FaultTcpServer
{
m_Port = m_Acceptor.local_endpoint().port();
StartAccept();
- m_Thread = std::thread([this]() { m_IoContext.run(); });
+ m_Thread = std::thread([this]() {
+ try
+ {
+ m_IoContext.run();
+ }
+ catch (...)
+ {
+ }
+ });
}
~FaultTcpServer()
{
- std::error_code Ec;
- m_Acceptor.close(Ec);
+ // io_context::stop() is thread-safe; do NOT call m_Acceptor.close() from this
+ // thread — ASIO I/O objects are not safe for concurrent access and the io_context
+ // thread may be touching the acceptor in StartAccept().
m_IoContext.stop();
if (m_Thread.joinable())
{
@@ -1081,6 +1091,105 @@ struct FaultTcpServer
}
};
+TEST_CASE("httpclient.range-response")
+{
+ ScopedTemporaryDirectory DownloadDir;
+
+ SUBCASE("single range 206 response populates Ranges")
+ {
+ std::string RangeBody(100, 'A');
+
+ FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = fmt::format(
+ "HTTP/1.1 206 Partial Content\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Range: bytes 200-299/1000\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n"
+ "{}",
+ RangeBody.size(),
+ RangeBody);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ });
+
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::PartialContent);
+ REQUIRE(Resp.Ranges.size() == 1);
+ CHECK_EQ(Resp.Ranges[0].RangeOffset, 200);
+ CHECK_EQ(Resp.Ranges[0].RangeLength, 100);
+ }
+
+ SUBCASE("multipart byteranges 206 response populates Ranges")
+ {
+ std::string Part1Data(16, 'X');
+ std::string Part2Data(12, 'Y');
+ std::string Boundary = "testboundary123";
+
+ std::string MultipartBody = fmt::format(
+ "\r\n--{}\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Range: bytes 100-115/1000\r\n"
+ "\r\n"
+ "{}"
+ "\r\n--{}\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Range: bytes 500-511/1000\r\n"
+ "\r\n"
+ "{}"
+ "\r\n--{}--",
+ Boundary,
+ Part1Data,
+ Boundary,
+ Part2Data,
+ Boundary);
+
+ FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = fmt::format(
+ "HTTP/1.1 206 Partial Content\r\n"
+ "Content-Type: multipart/byteranges; boundary={}\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n"
+ "{}",
+ Boundary,
+ MultipartBody.size(),
+ MultipartBody);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ });
+
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::PartialContent);
+ REQUIRE(Resp.Ranges.size() == 2);
+ // Ranges should be sorted by RangeOffset
+ CHECK_EQ(Resp.Ranges[0].RangeOffset, 100);
+ CHECK_EQ(Resp.Ranges[0].RangeLength, 16);
+ CHECK_EQ(Resp.Ranges[1].RangeOffset, 500);
+ CHECK_EQ(Resp.Ranges[1].RangeLength, 12);
+ }
+
+ SUBCASE("non-range 200 response has empty Ranges")
+ {
+ FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = MakeRawHttpResponse(200, "full content");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ });
+
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.Ranges.empty());
+ }
+}
+
TEST_CASE("httpclient.transport-faults" * doctest::skip())
{
SUBCASE("connection reset before response")
@@ -1354,6 +1463,188 @@ TEST_CASE("httpclient.transport-faults-post" * doctest::skip())
}
}
+TEST_CASE("httpclient.unixsocket")
+{
+ ScopedTemporaryDirectory TmpDir;
+ std::string SocketPath = (TmpDir.Path() / "zen.sock").string();
+
+ HttpClientTestService TestService;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath});
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto _ = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ HttpClientSettings Settings;
+ Settings.UnixSocketPath = SocketPath;
+
+ HttpClient Client("localhost", Settings, /*CheckIfAbortFunction*/ {});
+
+ SUBCASE("GET over unix socket")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("POST echo over unix socket")
+ {
+ const char* Payload = "unix socket payload";
+ IoBuffer Body(IoBuffer::Clone, Payload, strlen(Payload));
+ Body.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Body);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "unix socket payload");
+ }
+}
+
+# if ZEN_USE_OPENSSL
+
+TEST_CASE("httpclient.https")
+{
+ // Self-signed test certificate for localhost/127.0.0.1, valid until 2036
+ static constexpr std::string_view TestCertPem =
+ "-----BEGIN CERTIFICATE-----\n"
+ "MIIDJTCCAg2gAwIBAgIUEtJYMSUmJmvJ157We/qXNVJ7W8gwDQYJKoZIhvcNAQEL\n"
+ "BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMwOTIwMjU1M1oXDTM2MDMw\n"
+ "NjIwMjU1M1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\n"
+ "AAOCAQ8AMIIBCgKCAQEAv9YvZ6WeBz3z/Zuxi6OIivWksDxDZZ5oAXKVwlUXaa7v\n"
+ "iDkm9P5ZsEhN+M5vZMe2Yb9i3cnTUaE6Avs1ddOwTAYNGrE/B5DmibrRWc23R0cv\n"
+ "gdnYQJ+gjsAeMvUWYLK58xW4YoMR5bmfpj1ruqobUNkG/oJYnAUcjgo4J149irW+\n"
+ "4n9uLJvxL+5fI/b/AIkv+4TMe70/d/BPmnixWrrzxUT6S5ghE2Mq7+XLScfpY2Sp\n"
+ "GQ/Xbnj9/ELYLpQnNLuVZwWZDpXj+FLbF1zxgjYdw1cCjbRcOIEW2/GJeJvGXQ6Y\n"
+ "Vld5pCBm9uKPPLWoFCoakK5YvP00h+8X+HghGVSscQIDAQABo28wbTAdBgNVHQ4E\n"
+ "FgQUgM6hjymi6g2EBUg2ENu0nIK8yhMwHwYDVR0jBBgwFoAUgM6hjymi6g2EBUg2\n"
+ "ENu0nIK8yhMwDwYDVR0TAQH/BAUwAwEB/zAaBgNVHREEEzARhwR/AAABgglsb2Nh\n"
+ "bGhvc3QwDQYJKoZIhvcNAQELBQADggEBABY1oaaWwL4RaK/epKvk/IrmVT2mlAai\n"
+ "uvGLfjhc6FGvXaxPGTSUPrVbFornaWZAg7bOWCexWnEm2sWd75V/usvZAPN4aIiD\n"
+ "H66YQipq3OD4F9Gowp01IU4AcGh7MerFpYPk76+wp2ANq71x8axtlZjVn3hSFMmN\n"
+ "i6m9S/eyCl9WjYBT5ZEC4fJV0nOSmNe/+gCAm11/js9zNfXKmUchJtuZpubY3A0k\n"
+ "X2II6qYWf1PH+JJkefNZtt2c66CrEN5eAg4/rGEgsp43zcd4ZHVkpBKFLDEls1ev\n"
+ "drQ45zc4Ht77pHfnHu7YsLcRZ9Wq3COMNZYx5lItqnomX2qBm1pkwjI=\n"
+ "-----END CERTIFICATE-----\n";
+
+ static constexpr std::string_view TestKeyPem =
+ "-----BEGIN PRIVATE KEY-----\n"
+ "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC/1i9npZ4HPfP9\n"
+ "m7GLo4iK9aSwPENlnmgBcpXCVRdpru+IOSb0/lmwSE34zm9kx7Zhv2LdydNRoToC\n"
+ "+zV107BMBg0asT8HkOaJutFZzbdHRy+B2dhAn6COwB4y9RZgsrnzFbhigxHluZ+m\n"
+ "PWu6qhtQ2Qb+glicBRyOCjgnXj2Ktb7if24sm/Ev7l8j9v8AiS/7hMx7vT938E+a\n"
+ "eLFauvPFRPpLmCETYyrv5ctJx+ljZKkZD9dueP38QtgulCc0u5VnBZkOleP4UtsX\n"
+ "XPGCNh3DVwKNtFw4gRbb8Yl4m8ZdDphWV3mkIGb24o88tagUKhqQrli8/TSH7xf4\n"
+ "eCEZVKxxAgMBAAECggEAILd9pDaZqfCF8SWhdQgx3Ekiii/s6qLGaCDLq7XpZUvB\n"
+ "bEEbBMNwNmFOcvV6B/0LfMYwLVUjZhOSGjoPlwXAVmbdy0SZVEgBGVI0LBWqgUyB\n"
+ "rKqjd/oBXvci71vfMiSpE+0LYjmqTryGnspw2gfy2qn4yGUgiZNRmGPjycsHweUL\n"
+ "V3FHm3cf0dyE4sJ0mjVqZzRT/unw2QOCE6FlY7M1XxZL88IWfn6G4lckdJTwoOP5\n"
+ "VPR2J3XbyhvCeXeDRCHKRXojWWR2HovWnDXQc95GRgCd0vYdHuIUM6RXVPZQvy3X\n"
+ "l0GwQKHNcVr1uwtYDgGKw0tNCUDvxdfQaWilTFuicQKBgQDvEYp+vL1hnF+AVdu3\n"
+ "elsYsHpFgExkTI8wnUMvGZrFiIQyCyVDU3jkG3kcKacI1bfwopXopaQCjrYk9epm\n"
+ "liOVm3/Xtr6e2ENa7w8TQbdK65PciQNOMxml6g8clRRBl0cwj+aI3nW/Kop1cdrR\n"
+ "A9Vo+8iPTO5gDcxTiIb45a6E3QKBgQDNbE009P6ewx9PU7Llkhb9VBgsb7oQN3EV\n"
+ "TCYd4taiN6FPnTuL/cdijAA8y04hiVT+Efo9TUN9NCl9HdHXQcjj7/n/eFLH0Pkw\n"
+ "OIK3QN49OfR88wivLMtwWxIog0tJjc9+7dR4bR4o1jTlIrasEIvUTuDJQ8MKGc9v\n"
+ "pBITua+SpQKBgE4raSKZqj7hd6Sp7kbnHiRLiB9znQbqtaNKuK4M7DuMsNUAKfYC\n"
+ "tDO5+/bGc9SCtTtcnjHM/3zKlyossrFKhGYlyz6IhXnA8v0nz8EXKsy3jMh+kHMg\n"
+ "aFGE394TrOTphyCM3O+B9fRE/7L5QHg5ja1fLqwUlpkXyejCaoe16kONAoGAYIz9\n"
+ "wN1B67cEOVG6rOI8QfdLoV8mEcctNHhlFfjvLrF89SGOwl6WX0A0QF7CK0sUEpK6\n"
+ "jiOJjAh/U5o3bbgyxsedNjEEn3weE0cMUTuA+UALJMtKEqO4PuffIgGL2ld35k28\n"
+ "ZpnK6iC8HdJyD297eV9VkeNygYXeFLgF8xV8ay0CgYEAh4fmVZt9YhgVByYny2kF\n"
+ "ZUIkGF5h9wxzVOPpQwpizIGFFb3i/ZdGQcuLTfIBVRKf50sT3IwJe65ATv6+Lz0f\n"
+ "wg/pMvosi0/F5KGbVRVdzBMQy58WyyGti4tNl+8EXGvo8+DCmjlTYwfjRoZGg/qJ\n"
+ "EMP3/hTN7dHDRxPK8E0Fh0Y=\n"
+ "-----END PRIVATE KEY-----\n";
+
+ ScopedTemporaryDirectory TmpDir;
+
+ // Write cert and key to temp files
+ const auto CertPath = TmpDir.Path() / "test.crt";
+ const auto KeyPath = TmpDir.Path() / "test.key";
+ WriteFile(CertPath, IoBuffer(IoBuffer::Clone, TestCertPem.data(), TestCertPem.size()));
+ WriteFile(KeyPath, IoBuffer(IoBuffer::Clone, TestKeyPem.data(), TestKeyPem.size()));
+
+ HttpClientTestService TestService;
+
+ AsioConfig Config;
+ Config.CertFile = CertPath.string();
+ Config.KeyFile = KeyPath.string();
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(Config);
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto _ = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ int HttpsPort = Server->GetEffectiveHttpsPort();
+ REQUIRE(HttpsPort > 0);
+
+ HttpClientSettings Settings;
+ Settings.InsecureSsl = true;
+
+ HttpClient Client(fmt::format("https://127.0.0.1:{}", HttpsPort), Settings, /*CheckIfAbortFunction*/ {});
+
+ SUBCASE("GET over HTTPS")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("POST echo over HTTPS")
+ {
+ const char* Payload = "https payload";
+ IoBuffer Body(IoBuffer::Clone, Payload, strlen(Payload));
+ Body.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Body);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "https payload");
+ }
+
+ SUBCASE("GET JSON over HTTPS")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ CHECK(Resp.IsSuccess());
+ CbObject Obj = Resp.AsObject();
+ CHECK_EQ(Obj["ok"].AsBool(), true);
+ CHECK_EQ(Obj["message"].AsString(), "test");
+ }
+
+ SUBCASE("Large payload over HTTPS")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/large");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+ }
+}
+
+# endif // ZEN_USE_OPENSSL
+
TEST_SUITE_END();
void
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index 9bae95690..1a0018908 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -1044,13 +1044,16 @@ HttpServer::OnGetExternalHost() const
std::string
HttpServer::GetServiceUri(const HttpService* Service) const
{
+ const char* Scheme = (m_EffectiveHttpsPort > 0) ? "https" : "http";
+ int Port = (m_EffectiveHttpsPort > 0) ? m_EffectiveHttpsPort : m_EffectivePort;
+
if (Service)
{
- return fmt::format("http://{}:{}{}", m_ExternalHost, m_EffectivePort, Service->BaseUri());
+ return fmt::format("{}://{}:{}{}", Scheme, m_ExternalHost, Port, Service->BaseUri());
}
else
{
- return fmt::format("http://{}:{}", m_ExternalHost, m_EffectivePort);
+ return fmt::format("{}://{}:{}", Scheme, m_ExternalHost, Port);
}
}
@@ -1152,9 +1155,13 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig
if (ServerClass == "asio"sv)
{
ZEN_INFO("using asio HTTP server implementation")
- return CreateHttpAsioServer(AsioConfig{.ThreadCount = Config.ThreadCount,
- .ForceLoopback = Config.ForceLoopback,
- .IsDedicatedServer = Config.IsDedicatedServer});
+ return CreateHttpAsioServer(AsioConfig {
+ .ThreadCount = Config.ThreadCount, .ForceLoopback = Config.ForceLoopback, .IsDedicatedServer = Config.IsDedicatedServer,
+ .UnixSocketPath = Config.UnixSocketPath,
+#if ZEN_USE_OPENSSL
+ .HttpsPort = Config.HttpsPort, .CertFile = Config.CertFile, .KeyFile = Config.KeyFile,
+#endif
+ });
}
#if ZEN_WITH_HTTPSYS
else if (ServerClass == "httpsys"sv)
@@ -1165,7 +1172,11 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig
.IsAsyncResponseEnabled = Config.HttpSys.IsAsyncResponseEnabled,
.IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled,
.IsDedicatedServer = Config.IsDedicatedServer,
- .ForceLoopback = Config.ForceLoopback}));
+ .ForceLoopback = Config.ForceLoopback,
+ .HttpsPort = Config.HttpSys.HttpsPort,
+ .CertThumbprint = Config.HttpSys.CertThumbprint,
+ .CertStoreName = Config.HttpSys.CertStoreName,
+ .HttpsOnly = Config.HttpSys.HttpsOnly}));
}
#endif
else if (ServerClass == "null"sv)
diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h
index 57ab01158..90180391c 100644
--- a/src/zenhttp/include/zenhttp/formatters.h
+++ b/src/zenhttp/include/zenhttp/formatters.h
@@ -84,7 +84,7 @@ struct fmt::formatter<zen::HttpClient::Response>
return fmt::format_to(Ctx.out(),
"Failed: Elapsed: {}, Reason: ({}) '{}",
NiceResponseTime,
- Response.Error.value().ErrorCode,
+ static_cast<int>(Response.Error.value().ErrorCode),
Response.Error.value().ErrorMessage);
}
else
diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h
index 1bb36a298..2e21e3bd6 100644
--- a/src/zenhttp/include/zenhttp/httpclient.h
+++ b/src/zenhttp/include/zenhttp/httpclient.h
@@ -30,6 +30,34 @@ class CompositeBuffer;
*/
+enum class HttpClientErrorCode : int
+{
+ kOK = 0,
+ kConnectionFailure,
+ kHostResolutionFailure,
+ kProxyResolutionFailure,
+ kInternalError,
+ kNetworkReceiveError,
+ kNetworkSendFailure,
+ kOperationTimedOut,
+ kSSLConnectError,
+ kSSLCertificateError,
+ kSSLCACertError,
+ kGenericSSLError,
+ kRequestCancelled,
+ kOtherError,
+};
+
+enum class HttpClientBackend : uint8_t
+{
+ kDefault,
+ kCpr,
+ kCurl,
+};
+
+void SetDefaultHttpClientBackend(std::string_view Backend);
+void SetDefaultHttpClientBackend(HttpClientBackend Backend);
+
struct HttpClientAccessToken
{
using Clock = std::chrono::system_clock;
@@ -59,6 +87,22 @@ struct HttpClientSettings
Oid SessionId = Oid::Zero;
bool Verbose = false;
uint64_t MaximumInMemoryDownloadSize = 1024u * 1024u;
+ HttpClientBackend Backend = HttpClientBackend::kDefault;
+
+ /// Unix domain socket path. When non-empty, the client connects via this
+ /// socket instead of TCP. BaseUri is still used for the Host header and URL.
+ std::string UnixSocketPath;
+
+ /// Disable HTTP keep-alive by closing the connection after each request.
+ /// Useful for testing per-connection overhead.
+ bool ForbidReuseConnection = false;
+
+ /// Skip TLS certificate verification (for testing with self-signed certs).
+ bool InsecureSsl = false;
+
+ /// CA certificate bundle path for TLS verification. When non-empty, overrides
+ /// the system default CA store.
+ std::string CaBundlePath;
/// HTTP status codes that are expected and should not be logged as warnings.
/// 404 is always treated as expected regardless of this list.
@@ -70,22 +114,22 @@ class HttpClientError : public std::runtime_error
public:
using _Mybase = runtime_error;
- HttpClientError(const std::string& Message, int Error, HttpResponseCode ResponseCode)
+ HttpClientError(const std::string& Message, HttpClientErrorCode Error, HttpResponseCode ResponseCode)
: _Mybase(Message)
, m_Error(Error)
, m_ResponseCode(ResponseCode)
{
}
- HttpClientError(const char* Message, int Error, HttpResponseCode ResponseCode)
+ HttpClientError(const char* Message, HttpClientErrorCode Error, HttpResponseCode ResponseCode)
: _Mybase(Message)
, m_Error(Error)
, m_ResponseCode(ResponseCode)
{
}
- inline int GetInternalErrorCode() const { return m_Error; }
- inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; }
+ inline HttpClientErrorCode GetInternalErrorCode() const { return m_Error; }
+ inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; }
enum class ResponseClass : std::int8_t
{
@@ -112,8 +156,8 @@ public:
ResponseClass GetResponseClass() const;
private:
- const int m_Error = 0;
- const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot;
+ const HttpClientErrorCode m_Error = HttpClientErrorCode::kOK;
+ const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot;
};
class HttpClientBase;
@@ -137,11 +181,23 @@ public:
struct ErrorContext
{
- int ErrorCode = 0;
- std::string ErrorMessage;
+ HttpClientErrorCode ErrorCode;
+ std::string ErrorMessage;
/** True when the error is a transport-level connection failure (connect timeout, refused, DNS) */
- bool IsConnectionError() const;
+ bool IsConnectionError() const
+ {
+ switch (ErrorCode)
+ {
+ case HttpClientErrorCode::kConnectionFailure:
+ case HttpClientErrorCode::kOperationTimedOut:
+ case HttpClientErrorCode::kHostResolutionFailure:
+ case HttpClientErrorCode::kProxyResolutionFailure:
+ return true;
+ default:
+ return false;
+ }
+ }
};
struct KeyValueMap
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index 0e1714669..d98877d16 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -255,6 +255,9 @@ public:
*/
std::string_view GetExternalHost() const { return m_ExternalHost; }
+ /** Returns the effective HTTPS port, or 0 if HTTPS is not enabled. Only valid after Initialize(). */
+ int GetEffectiveHttpsPort() const { return m_EffectiveHttpsPort; }
+
/** Returns total bytes received and sent across all connections since server start. */
virtual uint64_t GetTotalBytesReceived() const { return 0; }
virtual uint64_t GetTotalBytesSent() const { return 0; }
@@ -290,7 +293,8 @@ public:
private:
std::vector<HttpService*> m_KnownServices;
- int m_EffectivePort = 0;
+ int m_EffectivePort = 0;
+ int m_EffectiveHttpsPort = 0;
std::string m_ExternalHost;
metrics::Meter m_RequestMeter;
std::string m_DefaultRedirect;
@@ -308,6 +312,7 @@ private:
virtual void OnClose() = 0;
protected:
+ void SetEffectiveHttpsPort(int Port) { m_EffectiveHttpsPort = Port; }
virtual std::string OnGetExternalHost() const;
};
@@ -324,12 +329,20 @@ struct HttpServerConfig
std::vector<HttpServerPluginConfig> PluginConfigs;
bool ForceLoopback = false;
unsigned int ThreadCount = 0;
+ std::string UnixSocketPath; // Unix domain socket path (empty = disabled, non-Windows only)
+ int HttpsPort = 0; // HTTPS listen port (0 = disabled, ASIO backend)
+ std::string CertFile; // PEM certificate chain file path
+ std::string KeyFile; // PEM private key file path
struct
{
unsigned int AsyncWorkThreadCount = 0;
bool IsAsyncResponseEnabled = true;
bool IsRequestLoggingEnabled = false;
+ int HttpsPort = 0; // 0 = HTTPS disabled
+ std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding
+ std::string CertStoreName = "MY"; // Windows certificate store name
+ bool HttpsOnly = false; // When true, disable HTTP listener
} HttpSys;
};
diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h
index 926ec1e3d..34d338b1d 100644
--- a/src/zenhttp/include/zenhttp/httpwsclient.h
+++ b/src/zenhttp/include/zenhttp/httpwsclient.h
@@ -43,6 +43,10 @@ struct HttpWsClientSettings
std::string LogCategory = "wsclient";
std::chrono::milliseconds ConnectTimeout{5000};
std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider;
+
+ /// Unix domain socket path. When non-empty, connects via this socket
+ /// instead of TCP. The URL host is still used for the Host header.
+ std::string UnixSocketPath;
};
/**
diff --git a/src/zenhttp/servers/asio_socket_traits.h b/src/zenhttp/servers/asio_socket_traits.h
new file mode 100644
index 000000000..25aeaa24e
--- /dev/null
+++ b/src/zenhttp/servers/asio_socket_traits.h
@@ -0,0 +1,54 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::asio_http {
+
+/**
+ * Traits for abstracting socket shutdown/close across plain TCP, Unix domain, and SSL sockets.
+ * SSL sockets need lowest_layer() access and have different shutdown semantics.
+ */
+template<typename SocketType>
+struct SocketTraits
+{
+ /// SSL sockets cannot use zero-copy file send (TransmitFile/sendfile) because
+ /// those bypass the encryption layer. This flag lets templated code fall back
+ /// to reading-into-memory for SSL connections.
+ static constexpr bool IsSslSocket = false;
+
+ static void ShutdownReceive(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_receive, Ec); }
+
+ static void ShutdownBoth(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_both, Ec); }
+
+ static void Close(SocketType& S, std::error_code& Ec) { S.close(Ec); }
+};
+
+#if ZEN_USE_OPENSSL
+using SslSocket = asio::ssl::stream<asio::ip::tcp::socket>;
+
+template<>
+struct SocketTraits<SslSocket>
+{
+ static constexpr bool IsSslSocket = true;
+
+ static void ShutdownReceive(SslSocket& S, std::error_code& Ec) { S.lowest_layer().shutdown(asio::socket_base::shutdown_receive, Ec); }
+
+ static void ShutdownBoth(SslSocket& S, std::error_code& Ec)
+ {
+ // Best-effort SSL close_notify, then TCP shutdown
+ S.shutdown(Ec);
+ S.lowest_layer().shutdown(asio::socket_base::shutdown_both, Ec);
+ }
+
+ static void Close(SslSocket& S, std::error_code& Ec) { S.lowest_layer().close(Ec); }
+};
+#endif
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index f5178ebe8..ee8e71256 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -1,6 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include "httpasio.h"
+#include "asio_socket_traits.h"
#include "httptracer.h"
#include <zencore/except.h>
@@ -35,6 +36,12 @@ ZEN_THIRD_PARTY_INCLUDES_START
#endif
#include <asio.hpp>
#include <asio/stream_file.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
ZEN_THIRD_PARTY_INCLUDES_END
#define ASIO_VERBOSE_TRACE 0
@@ -144,7 +151,17 @@ using namespace std::literals;
struct HttpAcceptor;
struct HttpResponse;
-struct HttpServerConnection;
+template<typename SocketType>
+struct HttpServerConnectionT;
+using HttpServerConnection = HttpServerConnectionT<asio::ip::tcp::socket>;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+struct UnixAcceptor;
+using UnixServerConnection = HttpServerConnectionT<asio::local::stream_protocol::socket>;
+#endif
+#if ZEN_USE_OPENSSL
+struct HttpsAcceptor;
+using HttpsSslServerConnection = HttpServerConnectionT<SslSocket>;
+#endif
inline LoggerRef
InitLogger()
@@ -176,9 +193,9 @@ Log()
#endif
#if ZEN_USE_TRANSMITFILE
-template<typename Handler>
+template<typename Handler, typename SocketType>
void
-TransmitFileAsync(asio::ip::tcp::socket& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb)
+TransmitFileAsync(SocketType& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb)
{
# if ZEN_BUILD_DEBUG
const uint64_t FileSize = FileSizeFromHandle(FileHandle);
@@ -511,11 +528,20 @@ public:
bool IsLoopbackOnly() const;
+ int GetEffectiveHttpsPort() const;
+
asio::io_service m_IoService;
asio::io_service::work m_Work{m_IoService};
std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
- std::vector<std::thread> m_ThreadPool;
- std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ std::unique_ptr<asio_http::UnixAcceptor> m_UnixAcceptor;
+#endif
+#if ZEN_USE_OPENSSL
+ std::unique_ptr<asio::ssl::context> m_SslContext;
+ std::unique_ptr<asio_http::HttpsAcceptor> m_HttpsAcceptor;
+#endif
+ std::vector<std::thread> m_ThreadPool;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
LoggerRef m_RequestLog;
HttpServerTracer m_RequestTracer;
@@ -573,6 +599,7 @@ public:
uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers
IoBuffer m_PayloadBuffer;
bool m_IsLocalMachineRequest;
+ bool m_AllowZeroCopyFileSend = true;
std::string m_RemoteAddress;
std::unique_ptr<HttpResponse> m_Response;
};
@@ -595,6 +622,8 @@ public:
~HttpResponse() = default;
+ void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; }
+
/**
* Initialize the response for sending a payload made up of multiple blobs
*
@@ -636,7 +665,7 @@ public:
bool ChunkHandled = false;
#if ZEN_USE_TRANSMITFILE || ZEN_USE_ASYNC_SENDFILE
- if (OwnedBuffer.IsWholeFile())
+ if (m_AllowZeroCopyFileSend && OwnedBuffer.IsWholeFile())
{
if (IoBufferFileReference FileRef; OwnedBuffer.GetFileReference(/* out */ FileRef))
{
@@ -751,7 +780,8 @@ public:
return m_Headers;
}
- void SendResponse(asio::ip::tcp::socket& TcpSocket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token)
+ template<typename SocketType>
+ void SendResponse(SocketType& Socket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token)
{
ZEN_ASSERT(m_State == State::kInitialized);
@@ -761,10 +791,11 @@ public:
m_SendCb = std::move(Token);
m_State = State::kSending;
- SendNextChunk(TcpSocket);
+ SendNextChunk(Socket);
}
- void SendNextChunk(asio::ip::tcp::socket& TcpSocket)
+ template<typename SocketType>
+ void SendNextChunk(SocketType& Socket)
{
ZEN_ASSERT(m_State == State::kSending);
@@ -781,12 +812,12 @@ public:
auto CompletionToken = [Self = this, Token = std::move(m_SendCb), TotalBytes = m_TotalBytesSent] { Token({}, TotalBytes); };
- asio::defer(TcpSocket.get_executor(), std::move(CompletionToken));
+ asio::defer(Socket.get_executor(), std::move(CompletionToken));
return;
}
- auto OnCompletion = [this, &TcpSocket](const asio::error_code& Ec, std::size_t ByteCount) {
+ auto OnCompletion = [this, &Socket](const asio::error_code& Ec, std::size_t ByteCount) {
ZEN_ASSERT(m_State == State::kSending);
m_TotalBytesSent += ByteCount;
@@ -797,7 +828,7 @@ public:
}
else
{
- SendNextChunk(TcpSocket);
+ SendNextChunk(Socket);
}
};
@@ -811,25 +842,21 @@ public:
Io.Ref.FileRef.FileChunkSize);
#if ZEN_USE_TRANSMITFILE
- TransmitFileAsync(TcpSocket,
+ TransmitFileAsync(Socket,
Io.Ref.FileRef.FileHandle,
Io.Ref.FileRef.FileChunkOffset,
gsl::narrow_cast<uint32_t>(Io.Ref.FileRef.FileChunkSize),
OnCompletion);
+ return;
#elif ZEN_USE_ASYNC_SENDFILE
- SendFileAsync(TcpSocket,
+ SendFileAsync(Socket,
Io.Ref.FileRef.FileHandle,
Io.Ref.FileRef.FileChunkOffset,
Io.Ref.FileRef.FileChunkSize,
64 * 1024,
OnCompletion);
-#else
- // This should never occur unless we compile with one
- // of the options above
- ZEN_WARN("invalid file reference in response");
-#endif
-
return;
+#endif
}
// Send as many consecutive non-file references as possible in one asio operation
@@ -850,7 +877,7 @@ public:
++m_IoVecCursor;
}
- asio::async_write(TcpSocket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion);
+ asio::async_write(Socket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion);
}
private:
@@ -863,12 +890,13 @@ private:
kFailed
};
- uint32_t m_RequestNumber = 0;
- uint16_t m_ResponseCode = 0;
- bool m_IsKeepAlive = true;
- State m_State = State::kUninitialized;
- HttpContentType m_ContentType = HttpContentType::kBinary;
- uint64_t m_ContentLength = 0;
+ uint32_t m_RequestNumber = 0;
+ uint16_t m_ResponseCode = 0;
+ bool m_IsKeepAlive = true;
+ bool m_AllowZeroCopyFileSend = true;
+ State m_State = State::kUninitialized;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ uint64_t m_ContentLength = 0;
eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive
ExtendableStringBuilder<160> m_Headers;
@@ -895,12 +923,13 @@ private:
//////////////////////////////////////////////////////////////////////////
-struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnection>
+template<typename SocketType>
+struct HttpServerConnectionT : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnectionT<SocketType>>
{
- HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket);
- ~HttpServerConnection();
+ HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket);
+ ~HttpServerConnectionT();
- std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); }
+ std::shared_ptr<HttpServerConnectionT> AsSharedPtr() { return this->shared_from_this(); }
// HttpConnectionBase implementation
@@ -962,12 +991,13 @@ private:
RwLock m_ActiveResponsesLock;
std::deque<std::unique_ptr<HttpResponse>> m_ActiveResponses;
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<SocketType> m_Socket;
};
std::atomic<uint32_t> g_ConnectionIdCounter{0};
-HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket)
+template<typename SocketType>
+HttpServerConnectionT<SocketType>::HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket)
: m_Server(Server)
, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1))
, m_Socket(std::move(Socket))
@@ -975,21 +1005,24 @@ HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::uniq
ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId);
}
-HttpServerConnection::~HttpServerConnection()
+template<typename SocketType>
+HttpServerConnectionT<SocketType>::~HttpServerConnectionT()
{
RwLock::ExclusiveLockScope _(m_ActiveResponsesLock);
ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId);
}
+template<typename SocketType>
void
-HttpServerConnection::HandleNewRequest()
+HttpServerConnectionT<SocketType>::HandleNewRequest()
{
EnqueueRead();
}
+template<typename SocketType>
void
-HttpServerConnection::TerminateConnection()
+HttpServerConnectionT<SocketType>::TerminateConnection()
{
if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated)
{
@@ -1001,12 +1034,13 @@ HttpServerConnection::TerminateConnection()
// Terminating, we don't care about any errors when closing socket
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_both, Ec);
- m_Socket->close(Ec);
+ SocketTraits<SocketType>::ShutdownBoth(*m_Socket, Ec);
+ SocketTraits<SocketType>::Close(*m_Socket, Ec);
}
+template<typename SocketType>
void
-HttpServerConnection::EnqueueRead()
+HttpServerConnectionT<SocketType>::EnqueueRead()
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1027,8 +1061,9 @@ HttpServerConnection::EnqueueRead()
[Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); });
}
+template<typename SocketType>
void
-HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+HttpServerConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1086,11 +1121,12 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]
}
}
+template<typename SocketType>
void
-HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
- [[maybe_unused]] std::size_t ByteCount,
- [[maybe_unused]] uint32_t RequestNumber,
- HttpResponse* ResponseToPop)
+HttpServerConnectionT<SocketType>::OnResponseDataSent(const asio::error_code& Ec,
+ [[maybe_unused]] std::size_t ByteCount,
+ [[maybe_unused]] uint32_t RequestNumber,
+ HttpResponse* ResponseToPop)
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1144,8 +1180,9 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
}
}
+template<typename SocketType>
void
-HttpServerConnection::CloseConnection()
+HttpServerConnectionT<SocketType>::CloseConnection()
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1157,23 +1194,24 @@ HttpServerConnection::CloseConnection()
m_RequestState = RequestState::kDone;
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
+ SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec);
if (Ec)
{
ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message());
}
- m_Socket->close(Ec);
+ SocketTraits<SocketType>::Close(*m_Socket, Ec);
if (Ec)
{
ZEN_WARN("socket close ERROR, reason '{}'", Ec.message());
}
}
+template<typename SocketType>
void
-HttpServerConnection::SendInlineResponse(uint32_t RequestNumber,
- std::string_view StatusLine,
- std::string_view Headers,
- std::string_view Body)
+HttpServerConnectionT<SocketType>::SendInlineResponse(uint32_t RequestNumber,
+ std::string_view StatusLine,
+ std::string_view Headers,
+ std::string_view Body)
{
ExtendableStringBuilder<256> ResponseBuilder;
ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n";
@@ -1194,15 +1232,16 @@ HttpServerConnection::SendInlineResponse(uint32_t RequestNumber,
IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size());
auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize());
asio::async_write(
- *m_Socket.get(),
+ *m_Socket,
Buffer,
[Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) {
Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
});
}
+template<typename SocketType>
void
-HttpServerConnection::HandleRequest()
+HttpServerConnectionT<SocketType>::HandleRequest()
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1229,24 +1268,25 @@ HttpServerConnection::HandleRequest()
ResponseStr->append("\r\n\r\n");
// Send the 101 response on the current socket, then hand the socket off
- // to a WsAsioConnection for the WebSocket protocol.
- asio::async_write(*m_Socket,
- asio::buffer(ResponseStr->data(), ResponseStr->size()),
- [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
- if (Ec)
- {
- ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
- return;
- }
-
- Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened();
- Ref<WsAsioConnection> WsConn(
- new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
- Ref<WebSocketConnection> WsConnRef(WsConn.Get());
-
- WsHandler->OnWebSocketOpen(std::move(WsConnRef));
- WsConn->Start();
- });
+ // to a WsAsioConnectionT for the WebSocket protocol.
+ asio::async_write(
+ *m_Socket,
+ asio::buffer(ResponseStr->data(), ResponseStr->size()),
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
+ return;
+ }
+
+ Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened();
+ using WsConnType = WsAsioConnectionT<SocketType>;
+ Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+ });
m_RequestState = RequestState::kDone;
return;
@@ -1260,7 +1300,7 @@ HttpServerConnection::HandleRequest()
m_RequestState = RequestState::kWritingFinal;
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
+ SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec);
if (Ec)
{
@@ -1280,15 +1320,36 @@ HttpServerConnection::HandleRequest()
m_Server.m_HttpServer->MarkRequest();
- auto RemoteEndpoint = m_Socket->remote_endpoint();
- bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address();
+ bool IsLocalConnection = true;
+ std::string RemoteAddress;
+
+ if constexpr (std::is_same_v<SocketType, asio::ip::tcp::socket>)
+ {
+ auto RemoteEndpoint = m_Socket->remote_endpoint();
+ IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address();
+ RemoteAddress = RemoteEndpoint.address().to_string();
+ }
+#if ZEN_USE_OPENSSL
+ else if constexpr (std::is_same_v<SocketType, SslSocket>)
+ {
+ auto RemoteEndpoint = m_Socket->lowest_layer().remote_endpoint();
+ IsLocalConnection = m_Socket->lowest_layer().local_endpoint().address() == RemoteEndpoint.address();
+ RemoteAddress = RemoteEndpoint.address().to_string();
+ }
+#endif
+ else
+ {
+ RemoteAddress = "unix";
+ }
HttpAsioServerRequest Request(m_RequestData,
*Service,
m_RequestData.Body(),
RequestNumber,
IsLocalConnection,
- RemoteEndpoint.address().to_string());
+ std::move(RemoteAddress));
+
+ Request.m_AllowZeroCopyFileSend = !SocketTraits<SocketType>::IsSslSocket;
ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber);
@@ -1439,14 +1500,23 @@ HttpServerConnection::HandleRequest()
}
//////////////////////////////////////////////////////////////////////////
+// Base class for TCP acceptors that handles socket setup, port binding
+// with probing/retry, and dual-stack (IPv6+IPv4 loopback) support.
+// Subclasses only need to implement OnAccept() to handle new connections.
-struct HttpAcceptor
+struct TcpAcceptorBase
{
- HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ TcpAcceptorBase(HttpAsioServerImpl& Server,
+ asio::io_service& IoService,
+ uint16_t BasePort,
+ bool ForceLoopback,
+ bool AllowPortProbing,
+ std::string_view Label)
: m_Server(Server)
, m_IoService(IoService)
, m_Acceptor(m_IoService, asio::ip::tcp::v6())
, m_AlternateProtocolAcceptor(m_IoService, asio::ip::tcp::v4())
+ , m_Label(Label)
{
const bool IsUsingIPv6 = IsIPv6Capable();
if (!IsUsingIPv6)
@@ -1455,7 +1525,6 @@ struct HttpAcceptor
}
#if ZEN_PLATFORM_WINDOWS
- // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms
typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> exclusive_address;
m_Acceptor.set_option(exclusive_address(true));
m_AlternateProtocolAcceptor.set_option(exclusive_address(true));
@@ -1468,83 +1537,54 @@ struct HttpAcceptor
#endif // ZEN_PLATFORM_WINDOWS
m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
- m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
m_AlternateProtocolAcceptor.set_option(asio::ip::tcp::no_delay(true));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
- std::string BoundBaseUrl;
if (IsUsingIPv6)
{
- BoundBaseUrl = BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing);
+ BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing);
}
else
{
- ZEN_INFO("NOTE: ipv6 support is disabled, binding to ipv4 only");
-
- BoundBaseUrl = BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing);
+ ZEN_INFO("{}: ipv6 support is disabled, binding to ipv4 only", m_Label);
+ BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing);
}
+ }
- if (!IsValid())
- {
- return;
- }
-
-#if ZEN_PLATFORM_WINDOWS
- // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
- // This must be used by both the client and server side, and is only effective in the absence of
- // Windows Filtering Platform (WFP) callouts which can be installed by security software.
- // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
- SOCKET NativeSocket = m_Acceptor.native_handle();
- int LoopbackOptionValue = 1;
- DWORD OptionNumberOfBytesReturned = 0;
- WSAIoctl(NativeSocket,
- SIO_LOOPBACK_FAST_PATH,
- &LoopbackOptionValue,
- sizeof(LoopbackOptionValue),
- NULL,
- 0,
- &OptionNumberOfBytesReturned,
- 0,
- 0);
-
- if (m_UseAlternateProtocolAcceptor)
- {
- NativeSocket = m_AlternateProtocolAcceptor.native_handle();
- WSAIoctl(NativeSocket,
- SIO_LOOPBACK_FAST_PATH,
- &LoopbackOptionValue,
- sizeof(LoopbackOptionValue),
- NULL,
- 0,
- &OptionNumberOfBytesReturned,
- 0,
- 0);
- }
-#endif
- m_Acceptor.listen();
+ virtual ~TcpAcceptorBase()
+ {
+ m_Acceptor.close();
if (m_UseAlternateProtocolAcceptor)
{
- m_AlternateProtocolAcceptor.listen();
+ m_AlternateProtocolAcceptor.close();
}
-
- ZEN_INFO("Started asio server at '{}", BoundBaseUrl);
}
- ~HttpAcceptor()
+ void Start()
{
- m_Acceptor.close();
+ ZEN_ASSERT(!m_IsStopped);
+ InitAcceptLoop(m_Acceptor);
if (m_UseAlternateProtocolAcceptor)
{
- m_AlternateProtocolAcceptor.close();
+ InitAcceptLoop(m_AlternateProtocolAcceptor);
}
}
+ void StopAccepting() { m_IsStopped = true; }
+
+ uint16_t GetPort() const { return m_Acceptor.local_endpoint().port(); }
+ bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); }
+ bool IsValid() const { return m_IsValid; }
+
+protected:
+ /// Called for each accepted TCP socket. Subclasses create the appropriate connection type.
+ virtual void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) = 0;
+
+ HttpAsioServerImpl& m_Server;
+ asio::io_service& m_IoService;
+
+private:
template<typename AddressType>
- std::string BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ void BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
{
uint16_t EffectivePort = BasePort;
@@ -1571,7 +1611,7 @@ struct HttpAcceptor
if (BindErrorCode == asio::error::access_denied && !BindAddress.is_loopback())
{
- ZEN_INFO("Access denied for public port {}, falling back to loopback", BasePort);
+ ZEN_INFO("{}: Access denied for public port {}, falling back to loopback", m_Label, BasePort);
BindAddress = AddressType::loopback();
@@ -1585,7 +1625,7 @@ struct HttpAcceptor
if (BindErrorCode == asio::error::address_in_use)
{
- ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message());
+ ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message());
Sleep(500);
m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode);
}
@@ -1601,7 +1641,8 @@ struct HttpAcceptor
if (BindErrorCode)
{
- ZEN_INFO("Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')",
+ ZEN_INFO("{}: Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')",
+ m_Label,
BindErrorCode.message());
EffectivePort = 0;
@@ -1617,7 +1658,7 @@ struct HttpAcceptor
{
for (uint32_t Retries = 0; (BindErrorCode == asio::error::address_in_use) && (Retries < 3); Retries++)
{
- ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message());
+ ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message());
Sleep(500);
m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode);
}
@@ -1625,14 +1666,13 @@ struct HttpAcceptor
if (BindErrorCode)
{
- ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message());
-
- return {};
+ ZEN_WARN("{}: Unable to bind on port {} (bind returned '{}')", m_Label, BasePort, BindErrorCode.message());
+ return;
}
if (EffectivePort != BasePort)
{
- ZEN_WARN("Desired port {} is in use, remapped to port {}", BasePort, EffectivePort);
+ ZEN_WARN("{}: Desired port {} is in use, remapped to port {}", m_Label, BasePort, EffectivePort);
}
if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>)
@@ -1642,55 +1682,64 @@ struct HttpAcceptor
// IPv6 loopback will only respond on the IPv6 loopback address. Not everyone does
// IPv6 though so we also bind to IPv4 loopback (localhost/127.0.0.1)
- m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), BindErrorCode);
+ asio::error_code AltEc;
+ m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), AltEc);
- if (BindErrorCode)
+ if (AltEc)
{
- ZEN_WARN("Failed to register secondary IPv4 local-only handler 'http://{}:{}/'", "localhost", EffectivePort);
+ ZEN_WARN("{}: Failed to register secondary IPv4 local-only handler on port {}", m_Label, EffectivePort);
}
else
{
m_UseAlternateProtocolAcceptor = true;
- ZEN_INFO("Registered local-only handler 'http://{}:{}/' - this is not accessible from remote hosts",
- "localhost",
- EffectivePort);
}
}
}
- m_IsValid = true;
+#if ZEN_PLATFORM_WINDOWS
+ // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
+ // This must be used by both the client and server side, and is only effective in the absence of
+ // Windows Filtering Platform (WFP) callouts which can be installed by security software.
+ // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
+ SOCKET NativeSocket = m_Acceptor.native_handle();
+ int LoopbackOptionValue = 1;
+ DWORD OptionNumberOfBytesReturned = 0;
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
- if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>)
- {
- return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "[::1]" : "*", EffectivePort);
- }
- else
+ if (m_UseAlternateProtocolAcceptor)
{
- return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "127.0.0.1" : "*", EffectivePort);
+ NativeSocket = m_AlternateProtocolAcceptor.native_handle();
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
}
- }
-
- void Start()
- {
- ZEN_MEMSCOPE(GetHttpasioTag());
+#endif
- ZEN_ASSERT(!m_IsStopped);
- InitAcceptInternal(m_Acceptor);
+ m_Acceptor.listen();
if (m_UseAlternateProtocolAcceptor)
{
- InitAcceptInternal(m_AlternateProtocolAcceptor);
+ m_AlternateProtocolAcceptor.listen();
}
- }
- void StopAccepting() { m_IsStopped = true; }
-
- int GetAcceptPort() const { return m_Acceptor.local_endpoint().port(); }
- bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); }
-
- bool IsValid() const { return m_IsValid; }
+ m_IsValid = true;
+ ZEN_INFO("{}: Listening on port {}", m_Label, m_Acceptor.local_endpoint().port());
+ }
-private:
- void InitAcceptInternal(asio::ip::tcp::acceptor& Acceptor)
+ void InitAcceptLoop(asio::ip::tcp::acceptor& Acceptor)
{
auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService);
asio::ip::tcp::socket& SocketRef = *SocketPtr.get();
@@ -1698,29 +1747,19 @@ private:
Acceptor.async_accept(SocketRef, [this, &Acceptor, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
if (Ec)
{
- ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'",
- Acceptor.local_endpoint().address().to_string(),
- Acceptor.local_endpoint().port(),
- Ec.message());
+ if (!m_IsStopped.load())
+ {
+ ZEN_WARN("{}: async_accept failed: '{}'", m_Label, Ec.message());
+ }
}
else
{
- // New connection established, pass socket ownership into connection object
- // and initiate request handling loop. The connection lifetime is
- // managed by the async read/write loop by passing the shared
- // reference to the callbacks.
-
- Socket->set_option(asio::ip::tcp::no_delay(true));
- Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
- auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
- Conn->HandleNewRequest();
+ OnAccept(std::move(Socket));
}
if (!m_IsStopped.load())
{
- InitAcceptInternal(Acceptor);
+ InitAcceptLoop(Acceptor);
}
else
{
@@ -1728,21 +1767,204 @@ private:
Acceptor.close(CloseEc);
if (CloseEc)
{
- ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message());
+ ZEN_WARN("{}: acceptor close error: '{}'", m_Label, CloseEc.message());
}
}
});
}
- HttpAsioServerImpl& m_Server;
- asio::io_service& m_IoService;
asio::ip::tcp::acceptor m_Acceptor;
asio::ip::tcp::acceptor m_AlternateProtocolAcceptor;
bool m_UseAlternateProtocolAcceptor{false};
bool m_IsValid{false};
std::atomic<bool> m_IsStopped{false};
+ std::string_view m_Label;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpAcceptor final : TcpAcceptorBase
+{
+ HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ : TcpAcceptorBase(Server, IoService, BasePort, ForceLoopback, AllowPortProbing, "HTTP")
+ {
+ }
+
+ int GetAcceptPort() const { return GetPort(); }
+
+protected:
+ void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override
+ {
+ Socket->set_option(asio::ip::tcp::no_delay(true));
+ Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
+ Conn->HandleNewRequest();
+ }
};
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+
+//////////////////////////////////////////////////////////////////////////
+
+struct UnixAcceptor
+{
+ UnixAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, const std::string& SocketPath)
+ : m_Server(Server)
+ , m_IoService(IoService)
+ , m_Acceptor(m_IoService)
+ , m_SocketPath(SocketPath)
+ {
+ // Remove any stale socket file from a previous run
+ std::filesystem::remove(m_SocketPath);
+
+ asio::local::stream_protocol::endpoint Endpoint(m_SocketPath);
+
+ asio::error_code Ec;
+ m_Acceptor.open(Endpoint.protocol(), Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to open unix domain socket: {}", Ec.message());
+ return;
+ }
+
+ m_Acceptor.bind(Endpoint, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to bind unix domain socket at '{}': {}", m_SocketPath, Ec.message());
+ return;
+ }
+
+ m_Acceptor.listen(asio::socket_base::max_listen_connections, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to listen on unix domain socket at '{}': {}", m_SocketPath, Ec.message());
+ return;
+ }
+
+ m_IsValid = true;
+ ZEN_INFO("Started unix domain socket listener at '{}'", m_SocketPath);
+ }
+
+ ~UnixAcceptor()
+ {
+ asio::error_code Ec;
+ m_Acceptor.close(Ec);
+ std::filesystem::remove(m_SocketPath);
+ }
+
+ void Start()
+ {
+ ZEN_ASSERT(!m_IsStopped);
+ InitAccept();
+ }
+
+ void StopAccepting() { m_IsStopped = true; }
+
+ bool IsValid() const { return m_IsValid; }
+
+private:
+ void InitAccept()
+ {
+ auto SocketPtr = std::make_unique<asio::local::stream_protocol::socket>(m_IoService);
+ asio::local::stream_protocol::socket& SocketRef = *SocketPtr.get();
+
+ m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ if (!m_IsStopped.load())
+ {
+ ZEN_WARN("unix domain socket async_accept failed: '{}'", Ec.message());
+ }
+ }
+ else
+ {
+ auto Conn = std::make_shared<UnixServerConnection>(m_Server, std::move(Socket));
+ Conn->HandleNewRequest();
+ }
+
+ if (!m_IsStopped.load())
+ {
+ InitAccept();
+ }
+ else
+ {
+ std::error_code CloseEc;
+ m_Acceptor.close(CloseEc);
+ }
+ });
+ }
+
+ HttpAsioServerImpl& m_Server;
+ asio::io_service& m_IoService;
+ asio::local::stream_protocol::acceptor m_Acceptor;
+ std::string m_SocketPath;
+ bool m_IsValid{false};
+ std::atomic<bool> m_IsStopped{false};
+};
+
+#endif // ASIO_HAS_LOCAL_SOCKETS
+
+#if ZEN_USE_OPENSSL
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpsAcceptor final : TcpAcceptorBase
+{
+ HttpsAcceptor(HttpAsioServerImpl& Server,
+ asio::io_service& IoService,
+ asio::ssl::context& SslContext,
+ uint16_t Port,
+ bool ForceLoopback,
+ bool AllowPortProbing)
+ : TcpAcceptorBase(Server, IoService, Port, ForceLoopback, AllowPortProbing, "HTTPS")
+ , m_SslContext(SslContext)
+ {
+ }
+
+protected:
+ void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override
+ {
+ Socket->set_option(asio::ip::tcp::no_delay(true));
+ Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ // Wrap accepted TCP socket in an SSL stream and perform the handshake
+ auto SslSocketPtr = std::make_unique<SslSocket>(std::move(*Socket), m_SslContext);
+
+ SslSocket& SslRef = *SslSocketPtr;
+ SslRef.async_handshake(asio::ssl::stream_base::server,
+ [this, SslSocket = std::move(SslSocketPtr)](const asio::error_code& HandshakeEc) mutable {
+ if (HandshakeEc)
+ {
+ ZEN_WARN("SSL handshake failed: '{}'", HandshakeEc.message());
+ std::error_code Ec;
+ SslSocket->lowest_layer().close(Ec);
+ return;
+ }
+
+ auto Conn = std::make_shared<HttpsSslServerConnection>(m_Server, std::move(SslSocket));
+ Conn->HandleNewRequest();
+ });
+ }
+
+private:
+ asio::ssl::context& m_SslContext;
+};
+
+#endif // ZEN_USE_OPENSSL
+
+int
+HttpAsioServerImpl::GetEffectiveHttpsPort() const
+{
+#if ZEN_USE_OPENSSL
+ return m_HttpsAcceptor ? m_HttpsAcceptor->GetPort() : 0;
+#else
+ return 0;
+#endif
+}
+
//////////////////////////////////////////////////////////////////////////
HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request,
@@ -1860,6 +2082,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
std::array<IoBuffer, 0> Empty;
m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
@@ -1873,6 +2096,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(ContentType, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
}
@@ -1883,6 +2107,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(ContentType, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size());
std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
@@ -1942,6 +2167,51 @@ HttpAsioServerImpl::Start(uint16_t Port, const AsioConfig& Config)
m_Acceptor->Start();
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (!Config.UnixSocketPath.empty())
+ {
+ m_UnixAcceptor.reset(new asio_http::UnixAcceptor(*this, m_IoService, Config.UnixSocketPath));
+
+ if (m_UnixAcceptor->IsValid())
+ {
+ m_UnixAcceptor->Start();
+ }
+ else
+ {
+ m_UnixAcceptor.reset();
+ }
+ }
+#endif
+
+#if ZEN_USE_OPENSSL
+ if (!Config.CertFile.empty() && !Config.KeyFile.empty())
+ {
+ m_SslContext = std::make_unique<asio::ssl::context>(asio::ssl::context::tlsv12_server);
+ m_SslContext->set_options(asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 |
+ asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1);
+ m_SslContext->use_certificate_chain_file(Config.CertFile);
+ m_SslContext->use_private_key_file(Config.KeyFile, asio::ssl::context::pem);
+
+ ZEN_INFO("SSL context initialized (cert: '{}', key: '{}')", Config.CertFile, Config.KeyFile);
+
+ m_HttpsAcceptor.reset(new asio_http::HttpsAcceptor(*this,
+ m_IoService,
+ *m_SslContext,
+ gsl::narrow<uint16_t>(Config.HttpsPort),
+ Config.ForceLoopback,
+ /*AllowPortProbing*/ !Config.IsDedicatedServer));
+
+ if (m_HttpsAcceptor->IsValid())
+ {
+ m_HttpsAcceptor->Start();
+ }
+ else
+ {
+ m_HttpsAcceptor.reset();
+ }
+ }
+#endif
+
// This should consist of a set of minimum threads and grow on demand to
// meet concurrency needs? Right now we end up allocating a large number
// of threads even if we never end up using all of them, which seems
@@ -1990,6 +2260,18 @@ HttpAsioServerImpl::Stop()
{
m_Acceptor->StopAccepting();
}
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixAcceptor)
+ {
+ m_UnixAcceptor->StopAccepting();
+ }
+#endif
+#if ZEN_USE_OPENSSL
+ if (m_HttpsAcceptor)
+ {
+ m_HttpsAcceptor->StopAccepting();
+ }
+#endif
m_IoService.stop();
for (auto& Thread : m_ThreadPool)
{
@@ -1999,7 +2281,23 @@ HttpAsioServerImpl::Stop()
}
}
m_ThreadPool.clear();
+
+ // Drain remaining handlers (e.g. cancellation callbacks from active WebSocket
+ // connections) so that their captured Ref<> pointers are released while the
+ // io_service and its epoll reactor are still alive. Without this, sockets
+ // held by external code (e.g. IWebSocketHandler connection lists) can outlive
+ // the reactor and crash during deregistration.
+ m_IoService.restart();
+ m_IoService.poll();
+
m_Acceptor.reset();
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ m_UnixAcceptor.reset();
+#endif
+#if ZEN_USE_OPENSSL
+ m_HttpsAcceptor.reset();
+ m_SslContext.reset();
+#endif
}
void
@@ -2166,6 +2464,13 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Config);
+#if ZEN_USE_OPENSSL
+ if (int EffectiveHttpsPort = m_Impl->GetEffectiveHttpsPort(); EffectiveHttpsPort > 0)
+ {
+ SetEffectiveHttpsPort(EffectiveHttpsPort);
+ }
+#endif
+
return m_BasePort;
}
diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h
index 3ec1141a7..5adf4d5e8 100644
--- a/src/zenhttp/servers/httpasio.h
+++ b/src/zenhttp/servers/httpasio.h
@@ -11,6 +11,12 @@ struct AsioConfig
unsigned int ThreadCount = 0;
bool ForceLoopback = false;
bool IsDedicatedServer = false;
+ std::string UnixSocketPath;
+#if ZEN_USE_OPENSSL
+ int HttpsPort = 0; // 0 = auto-assign; set CertFile/KeyFile to enable HTTPS
+ std::string CertFile; // PEM certificate chain file (empty = HTTPS disabled)
+ std::string KeyFile; // PEM private key file
+#endif
};
Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config);
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index dfe6bb6aa..83b98013e 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -116,6 +116,12 @@ public:
private:
int InitializeServer(int BasePort);
+ bool CreateSessionAndUrlGroup();
+ bool RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris);
+ int RegisterHttpUrls(int BasePort);
+ bool RegisterHttpsUrls();
+ bool CreateRequestQueue(int EffectivePort);
+ bool SetupIoCompletionPort();
void Cleanup();
void StartServer();
@@ -125,6 +131,9 @@ private:
void RegisterService(const char* Endpoint, HttpService& Service);
void UnregisterService(const char* Endpoint, HttpService& Service);
+ bool BindSslCertificate(int Port);
+ void UnbindSslCertificate();
+
private:
LoggerRef m_Log;
LoggerRef m_RequestLog;
@@ -140,7 +149,10 @@ private:
RwLock m_AsyncWorkPoolInitLock;
std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr;
- std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ std::vector<std::wstring> m_HttpsBaseUris; // eg: https://*:nnnn/
+ bool m_DidAutoBindCert = false;
+ int m_HttpsPort = 0;
HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0;
HANDLE m_RequestQueueHandle = 0;
@@ -1082,39 +1094,63 @@ HttpSysServer::OnClose()
}
}
-int
-HttpSysServer::InitializeServer(int BasePort)
+bool
+HttpSysServer::CreateSessionAndUrlGroup()
{
- ZEN_MEMSCOPE(GetHttpsysTag());
-
- using namespace std::literals;
-
- WideStringBuilder<64> WildcardUrlPath;
- WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
-
- m_IsOk = false;
-
ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})",
- WideToUtf8(WildcardUrlPath),
- GetSystemErrorAsString(Result),
- Result);
+ ZEN_ERROR("Failed to create server session: {} ({:#x})", GetSystemErrorAsString(Result), Result);
- return 0;
+ return false;
}
Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result);
+ ZEN_ERROR("Failed to create URL group: {} ({:#x})", GetSystemErrorAsString(Result), Result);
- return 0;
+ return false;
}
+ return true;
+}
+
+bool
+HttpSysServer::RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris)
+{
+ using namespace std::literals;
+
+ const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
+
+ for (const std::u8string_view Host : Hosts)
+ {
+ WideStringBuilder<64> LocalUrl;
+ LocalUrl << Scheme << u8"://"sv << Host << u8":"sv << int64_t(Port) << u8"/"sv;
+
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrl.c_str(), HTTP_URL_CONTEXT(0), 0);
+
+ if (Result == NO_ERROR)
+ {
+ ZEN_WARN("Registered local-only handler '{}' - this is not accessible from remote hosts", WideToUtf8(LocalUrl));
+ OutUris.push_back(LocalUrl.c_str());
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ return !OutUris.empty();
+}
+
+int
+HttpSysServer::RegisterHttpUrls(int BasePort)
+{
+ using namespace std::literals;
+
m_BaseUris.clear();
const bool AllowPortProbing = !m_InitialConfig.IsDedicatedServer;
@@ -1122,6 +1158,11 @@ HttpSysServer::InitializeServer(int BasePort)
int EffectivePort = BasePort;
+ WideStringBuilder<64> WildcardUrlPath;
+ WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
+
+ ULONG Result;
+
if (m_InitialConfig.ForceLoopback)
{
// Force trigger of opening using local port
@@ -1177,11 +1218,11 @@ HttpSysServer::InitializeServer(int BasePort)
{
if (AllowLocalOnly)
{
- // If we can't register the wildcard path, we fall back to local paths
- // This local paths allow requests originating locally to function, but will not allow
- // remote origin requests to function. This can be remedied by using netsh
+ // If we can't register the wildcard path, we fall back to local paths.
+ // Local paths allow requests originating locally to function, but will not allow
+ // remote origin requests to function. This can be remedied by using netsh
// during an install process to grant permissions to route public access to the appropriate
- // port for the current user. eg:
+ // port for the current user. eg:
// netsh http add urlacl url=http://*:8558/ user=<some_user>
if (!m_InitialConfig.ForceLoopback)
@@ -1246,7 +1287,7 @@ HttpSysServer::InitializeServer(int BasePort)
}
}
- if (m_BaseUris.empty())
+ if (m_BaseUris.empty() && m_InitialConfig.HttpsPort == 0)
{
ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})",
WideToUtf8(WildcardUrlPath),
@@ -1256,16 +1297,104 @@ HttpSysServer::InitializeServer(int BasePort)
return 0;
}
+ return EffectivePort;
+}
+
+bool
+HttpSysServer::RegisterHttpsUrls()
+{
+ using namespace std::literals;
+
+ const bool AllowLocalOnly = !m_InitialConfig.IsDedicatedServer;
+ const int HttpsPort = m_InitialConfig.HttpsPort;
+
+ // If HTTPS-only mode, remove HTTP URLs and clear base URIs
+ if (m_InitialConfig.HttpsOnly)
+ {
+ for (const std::wstring& Uri : m_BaseUris)
+ {
+ HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Uri.c_str(), 0);
+ }
+ m_BaseUris.clear();
+ }
+
+ // Auto-bind certificate if thumbprint is provided
+ if (!m_InitialConfig.CertThumbprint.empty())
+ {
+ if (!BindSslCertificate(HttpsPort))
+ {
+ return false;
+ }
+ }
+ else
+ {
+ ZEN_INFO("HTTPS port {} configured without thumbprint - assuming pre-registered SSL certificate", HttpsPort);
+ }
+
+ // Register HTTPS URLs using same pattern as HTTP
+
+ WideStringBuilder<64> HttpsWildcard;
+ HttpsWildcard << u8"https://*:"sv << int64_t(HttpsPort) << u8"/"sv;
+
+ ULONG HttpsResult = NO_ERROR;
+
+ if (m_InitialConfig.ForceLoopback)
+ {
+ HttpsResult = ERROR_ACCESS_DENIED;
+ }
+ else
+ {
+ HttpsResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, HttpsWildcard.c_str(), HTTP_URL_CONTEXT(0), 0);
+ }
+
+ if (HttpsResult == NO_ERROR)
+ {
+ m_HttpsBaseUris.push_back(HttpsWildcard.c_str());
+ }
+ else if (HttpsResult == ERROR_ACCESS_DENIED && AllowLocalOnly)
+ {
+ if (!m_InitialConfig.ForceLoopback)
+ {
+ ZEN_WARN(
+ "Unable to register HTTPS handler using '{}' - falling back to local-only. "
+ "Please ensure the appropriate netsh URL reservation and SSL certificate configuration is made.",
+ WideToUtf8(HttpsWildcard));
+ }
+
+ RegisterLocalUrls(u8"https", HttpsPort, m_HttpsBaseUris);
+ }
+ else if (HttpsResult != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to register HTTPS URL '{}': {} ({:#x})",
+ WideToUtf8(HttpsWildcard),
+ GetSystemErrorAsString(HttpsResult),
+ HttpsResult);
+ return false;
+ }
+
+ if (m_HttpsBaseUris.empty())
+ {
+ ZEN_ERROR("Failed to register any HTTPS URL for port {}", HttpsPort);
+ return false;
+ }
+
+ m_HttpsPort = HttpsPort;
+ return true;
+}
+
+bool
+HttpSysServer::CreateRequestQueue(int EffectivePort)
+{
HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0};
WideStringBuilder<64> QueueName;
QueueName << "zenserver_" << EffectivePort;
- Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
- /* Name */ QueueName.c_str(),
- /* SecurityAttributes */ nullptr,
- /* Flags */ 0,
- &m_RequestQueueHandle);
+ ULONG Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
+ /* Name */ QueueName.c_str(),
+ /* SecurityAttributes */ nullptr,
+ /* Flags */ 0,
+ &m_RequestQueueHandle);
if (Result != NO_ERROR)
{
@@ -1274,7 +1403,7 @@ HttpSysServer::InitializeServer(int BasePort)
GetSystemErrorAsString(Result),
Result);
- return 0;
+ return false;
}
HttpBindingInfo.Flags.Present = 1;
@@ -1289,7 +1418,7 @@ HttpSysServer::InitializeServer(int BasePort)
GetSystemErrorAsString(Result),
Result);
- return 0;
+ return false;
}
// Configure rejection method. Default is to drop the connection, it's better if we
@@ -1323,22 +1452,77 @@ HttpSysServer::InitializeServer(int BasePort)
}
}
- // Create I/O completion port
+ return true;
+}
+bool
+HttpSysServer::SetupIoCompletionPort()
+{
std::error_code ErrorCode;
m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode);
if (ErrorCode)
{
- ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message());
+ ZEN_ERROR("Failed to create IOCP: {}", ErrorCode.message());
+ return false;
+ }
+ m_IsOk = true;
+
+ if (!m_BaseUris.empty())
+ {
+ ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ }
+ if (!m_HttpsBaseUris.empty())
+ {
+ ZEN_INFO("Started http.sys HTTPS server at '{}'", WideToUtf8(m_HttpsBaseUris.front()));
+ }
+
+ return true;
+}
+
+int
+HttpSysServer::InitializeServer(int BasePort)
+{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
+ m_IsOk = false;
+
+ if (!CreateSessionAndUrlGroup())
+ {
return 0;
}
- else
+
+ int EffectivePort = RegisterHttpUrls(BasePort);
+
+ if (m_InitialConfig.HttpsPort > 0)
+ {
+ if (!RegisterHttpsUrls())
+ {
+ return 0;
+ }
+ }
+
+ if (m_BaseUris.empty() && m_HttpsBaseUris.empty())
{
- m_IsOk = true;
+ ZEN_ERROR("No HTTP or HTTPS listeners could be registered");
+ return 0;
+ }
- ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ if (!CreateRequestQueue(EffectivePort))
+ {
+ return 0;
+ }
+
+ if (!SetupIoCompletionPort())
+ {
+ return 0;
+ }
+
+ // When HTTPS-only, return the HTTPS port as the effective port
+ if (m_InitialConfig.HttpsOnly && m_HttpsPort > 0)
+ {
+ return m_HttpsPort;
}
return EffectivePort;
@@ -1349,6 +1533,8 @@ HttpSysServer::Cleanup()
{
++m_IsShuttingDown;
+ UnbindSslCertificate();
+
if (m_RequestQueueHandle)
{
HttpCloseRequestQueue(m_RequestQueueHandle);
@@ -1368,6 +1554,105 @@ HttpSysServer::Cleanup()
}
}
+// {7E3F4B2A-1C8D-4A6E-B5F0-9D2E8C7A3B1F} - Fixed GUID for zenserver SSL bindings
+static constexpr GUID ZenServerSslAppId = {0x7E3F4B2A, 0x1C8D, 0x4A6E, {0xB5, 0xF0, 0x9D, 0x2E, 0x8C, 0x7A, 0x3B, 0x1F}};
+
+bool
+HttpSysServer::BindSslCertificate(int Port)
+{
+ const std::string& Thumbprint = m_InitialConfig.CertThumbprint;
+ if (Thumbprint.size() != 40)
+ {
+ ZEN_ERROR("SSL certificate thumbprint must be exactly 40 hex characters, got {}", Thumbprint.size());
+ return false;
+ }
+
+ BYTE CertHash[20] = {};
+ if (!ParseHexBytes(Thumbprint, CertHash))
+ {
+ ZEN_ERROR("SSL certificate thumbprint contains invalid hex characters");
+ return false;
+ }
+
+ SOCKADDR_IN Address = {};
+ Address.sin_family = AF_INET;
+ Address.sin_port = htons(static_cast<USHORT>(Port));
+ Address.sin_addr.s_addr = INADDR_ANY;
+
+ const std::wstring StoreNameW = UTF8_to_UTF16(m_InitialConfig.CertStoreName.c_str());
+
+ HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {};
+ SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+ SslConfig.ParamDesc.pSslHash = CertHash;
+ SslConfig.ParamDesc.SslHashLength = sizeof(CertHash);
+ SslConfig.ParamDesc.pSslCertStoreName = const_cast<PWSTR>(StoreNameW.c_str());
+ SslConfig.ParamDesc.AppId = ZenServerSslAppId;
+
+ ULONG Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+
+ if (Result == ERROR_ALREADY_EXISTS)
+ {
+ // Remove existing binding and retry
+ HTTP_SERVICE_CONFIG_SSL_SET DeleteConfig = {};
+ DeleteConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+
+ HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &DeleteConfig, sizeof(DeleteConfig), nullptr);
+
+ Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+ }
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR(
+ "Failed to bind SSL certificate to port {}: {} ({:#x}). "
+ "This operation may require running as administrator.",
+ Port,
+ GetSystemErrorAsString(Result),
+ Result);
+ return false;
+ }
+
+ m_DidAutoBindCert = true;
+ m_HttpsPort = Port;
+
+ ZEN_INFO("SSL certificate auto-bound for 0.0.0.0:{} (thumbprint: {}..., store: {})",
+ Port,
+ Thumbprint.substr(0, 8),
+ m_InitialConfig.CertStoreName);
+
+ return true;
+}
+
+void
+HttpSysServer::UnbindSslCertificate()
+{
+ if (!m_DidAutoBindCert)
+ {
+ return;
+ }
+
+ SOCKADDR_IN Address = {};
+ Address.sin_family = AF_INET;
+ Address.sin_port = htons(static_cast<USHORT>(m_HttpsPort));
+ Address.sin_addr.s_addr = INADDR_ANY;
+
+ HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {};
+ SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+
+ ULONG Result = HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_WARN("Failed to remove SSL certificate binding from port {}: {} ({:#x})", m_HttpsPort, GetSystemErrorAsString(Result), Result);
+ }
+ else
+ {
+ ZEN_INFO("SSL certificate binding removed from port {}", m_HttpsPort);
+ }
+
+ m_DidAutoBindCert = false;
+}
+
WorkerThreadPool&
HttpSysServer::WorkPool()
{
@@ -1495,19 +1780,23 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service)
// Convert to wide string
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
-
- ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
-
- if (Result != NO_ERROR)
+ auto RegisterWithBaseUris = [&](const std::vector<std::wstring>& BaseUris) {
+ for (const std::wstring& BaseUri : BaseUris)
{
- ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ std::wstring Url16 = BaseUri + PathUtf16;
- return;
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ return;
+ }
}
- }
+ };
+
+ RegisterWithBaseUris(m_BaseUris);
+ RegisterWithBaseUris(m_HttpsBaseUris);
}
void
@@ -1522,19 +1811,22 @@ HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service)
const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath);
- // Convert to wide string
-
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
+ auto UnregisterFromBaseUris = [&](const std::vector<std::wstring>& BaseUris) {
+ for (const std::wstring& BaseUri : BaseUris)
+ {
+ std::wstring Url16 = BaseUri + PathUtf16;
- ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
+ ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
- if (Result != NO_ERROR)
- {
- ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ }
}
- }
+ };
+
+ UnregisterFromBaseUris(m_BaseUris);
+ UnregisterFromBaseUris(m_HttpsBaseUris);
}
//////////////////////////////////////////////////////////////////////////
@@ -2422,6 +2714,11 @@ HttpSysServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
ZEN_UNUSED(DataDir);
if (int EffectivePort = InitializeServer(BasePort))
{
+ if (m_HttpsPort > 0)
+ {
+ SetEffectiveHttpsPort(m_HttpsPort);
+ }
+
StartServer();
return EffectivePort;
diff --git a/src/zenhttp/servers/httpsys.h b/src/zenhttp/servers/httpsys.h
index b2fe7475b..ca465ad00 100644
--- a/src/zenhttp/servers/httpsys.h
+++ b/src/zenhttp/servers/httpsys.h
@@ -22,6 +22,10 @@ struct HttpSysConfig
bool IsRequestLoggingEnabled = false;
bool IsDedicatedServer = false;
bool ForceLoopback = false;
+ int HttpsPort = 0; // 0 = HTTPS disabled
+ std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding
+ std::string CertStoreName = "MY"; // Windows certificate store name
+ bool HttpsOnly = false; // When true, disable HTTP listener
};
Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config);
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
index b2543277a..5ae48f5b3 100644
--- a/src/zenhttp/servers/wsasio.cpp
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -1,6 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include "wsasio.h"
+#include "asio_socket_traits.h"
#include "wsframecodec.h"
#include <zencore/logging.h>
@@ -17,14 +18,16 @@ WsLog()
//////////////////////////////////////////////////////////////////////////
-WsAsioConnection::WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server)
+template<typename SocketType>
+WsAsioConnectionT<SocketType>::WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server)
: m_Socket(std::move(Socket))
, m_Handler(Handler)
, m_HttpServer(Server)
{
}
-WsAsioConnection::~WsAsioConnection()
+template<typename SocketType>
+WsAsioConnectionT<SocketType>::~WsAsioConnectionT()
{
m_IsOpen.store(false);
if (m_HttpServer)
@@ -33,14 +36,16 @@ WsAsioConnection::~WsAsioConnection()
}
}
+template<typename SocketType>
void
-WsAsioConnection::Start()
+WsAsioConnectionT<SocketType>::Start()
{
EnqueueRead();
}
+template<typename SocketType>
bool
-WsAsioConnection::IsOpen() const
+WsAsioConnectionT<SocketType>::IsOpen() const
{
return m_IsOpen.load(std::memory_order_relaxed);
}
@@ -50,23 +55,25 @@ WsAsioConnection::IsOpen() const
// Read loop
//
+template<typename SocketType>
void
-WsAsioConnection::EnqueueRead()
+WsAsioConnectionT<SocketType>::EnqueueRead()
{
if (!m_IsOpen.load(std::memory_order_relaxed))
{
return;
}
- Ref<WsAsioConnection> Self(this);
+ Ref<WsAsioConnectionT> Self(this);
asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) {
Self->OnDataReceived(Ec, ByteCount);
});
}
+template<typename SocketType>
void
-WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+WsAsioConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
if (Ec)
{
@@ -90,8 +97,9 @@ WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] st
}
}
+template<typename SocketType>
void
-WsAsioConnection::ProcessReceivedData()
+WsAsioConnectionT<SocketType>::ProcessReceivedData()
{
while (m_ReadBuffer.size() > 0)
{
@@ -162,8 +170,8 @@ WsAsioConnection::ProcessReceivedData()
// Shut down the socket
std::error_code ShutdownEc;
- m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc);
- m_Socket->close(ShutdownEc);
+ SocketTraits<SocketType>::ShutdownBoth(*m_Socket, ShutdownEc);
+ SocketTraits<SocketType>::Close(*m_Socket, ShutdownEc);
return;
}
@@ -179,8 +187,9 @@ WsAsioConnection::ProcessReceivedData()
// Write queue
//
+template<typename SocketType>
void
-WsAsioConnection::SendText(std::string_view Text)
+WsAsioConnectionT<SocketType>::SendText(std::string_view Text)
{
if (!m_IsOpen.load(std::memory_order_relaxed))
{
@@ -192,8 +201,9 @@ WsAsioConnection::SendText(std::string_view Text)
EnqueueWrite(std::move(Frame));
}
+template<typename SocketType>
void
-WsAsioConnection::SendBinary(std::span<const uint8_t> Data)
+WsAsioConnectionT<SocketType>::SendBinary(std::span<const uint8_t> Data)
{
if (!m_IsOpen.load(std::memory_order_relaxed))
{
@@ -204,14 +214,16 @@ WsAsioConnection::SendBinary(std::span<const uint8_t> Data)
EnqueueWrite(std::move(Frame));
}
+template<typename SocketType>
void
-WsAsioConnection::Close(uint16_t Code, std::string_view Reason)
+WsAsioConnectionT<SocketType>::Close(uint16_t Code, std::string_view Reason)
{
DoClose(Code, Reason);
}
+template<typename SocketType>
void
-WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason)
+WsAsioConnectionT<SocketType>::DoClose(uint16_t Code, std::string_view Reason)
{
if (!m_IsOpen.exchange(false))
{
@@ -227,8 +239,9 @@ WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason)
m_Handler.OnWebSocketClose(*this, Code, Reason);
}
+template<typename SocketType>
void
-WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame)
+WsAsioConnectionT<SocketType>::EnqueueWrite(std::vector<uint8_t> Frame)
{
if (m_HttpServer)
{
@@ -252,8 +265,9 @@ WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame)
}
}
+template<typename SocketType>
void
-WsAsioConnection::FlushWriteQueue()
+WsAsioConnectionT<SocketType>::FlushWriteQueue()
{
std::vector<uint8_t> Frame;
@@ -272,7 +286,7 @@ WsAsioConnection::FlushWriteQueue()
return;
}
- Ref<WsAsioConnection> Self(this);
+ Ref<WsAsioConnectionT> Self(this);
// Move Frame into a shared_ptr so we can create the buffer and capture ownership
// in the same async_write call without evaluation order issues.
@@ -283,8 +297,9 @@ WsAsioConnection::FlushWriteQueue()
[Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); });
}
+template<typename SocketType>
void
-WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+WsAsioConnectionT<SocketType>::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
if (Ec)
{
@@ -308,4 +323,17 @@ WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] s
FlushWriteQueue();
}
+//////////////////////////////////////////////////////////////////////////
+// Explicit template instantiations
+
+template class WsAsioConnectionT<asio::ip::tcp::socket>;
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+template class WsAsioConnectionT<asio::local::stream_protocol::socket>;
+#endif
+
+#if ZEN_USE_OPENSSL
+template class WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>;
+#endif
+
} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h
index e8bb3b1d2..64602ee46 100644
--- a/src/zenhttp/servers/wsasio.h
+++ b/src/zenhttp/servers/wsasio.h
@@ -8,6 +8,12 @@
ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
ZEN_THIRD_PARTY_INCLUDES_END
#include <deque>
@@ -21,22 +27,23 @@ class HttpServer;
namespace zen::asio_http {
/**
- * WebSocket connection over an ASIO TCP socket
+ * WebSocket connection over an ASIO stream socket
*
- * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake)
+ * Templated on SocketType to support both TCP and Unix domain sockets.
+ * Owns the socket (moved from HttpServerConnection after the 101 handshake)
* and runs an async read/write loop to exchange WebSocket frames.
*
* Lifetime is managed solely through intrusive reference counting (RefCounted).
- * The async read/write callbacks capture Ref<WsAsioConnection> to keep the
- * connection alive for the duration of the async operation. The service layer
- * also holds a Ref<WebSocketConnection>.
+ * The async read/write callbacks capture Ref<> to keep the connection alive
+ * for the duration of the async operation. The service layer also holds a
+ * Ref<WebSocketConnection>.
*/
-
-class WsAsioConnection : public WebSocketConnection
+template<typename SocketType>
+class WsAsioConnectionT : public WebSocketConnection
{
public:
- WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server);
- ~WsAsioConnection() override;
+ WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server);
+ ~WsAsioConnectionT() override;
/**
* Start the async read loop. Must be called once after construction
@@ -61,10 +68,10 @@ private:
void DoClose(uint16_t Code, std::string_view Reason);
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- IWebSocketHandler& m_Handler;
- zen::HttpServer* m_HttpServer;
- asio::streambuf m_ReadBuffer;
+ std::unique_ptr<SocketType> m_Socket;
+ IWebSocketHandler& m_Handler;
+ zen::HttpServer* m_HttpServer;
+ asio::streambuf m_ReadBuffer;
RwLock m_WriteLock;
std::deque<std::vector<uint8_t>> m_WriteQueue;
@@ -74,4 +81,14 @@ private:
std::atomic<bool> m_CloseSent{false};
};
+using WsAsioConnection = WsAsioConnectionT<asio::ip::tcp::socket>;
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+using WsAsioUnixConnection = WsAsioConnectionT<asio::local::stream_protocol::socket>;
+#endif
+
+#if ZEN_USE_OPENSSL
+using WsAsioSslConnection = WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>;
+#endif
+
} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
index 2134e4ff1..042afd8ff 100644
--- a/src/zenhttp/servers/wstest.cpp
+++ b/src/zenhttp/servers/wstest.cpp
@@ -485,7 +485,7 @@ TEST_CASE("websocket.integration")
Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
- int Port = Server->Initialize(7575, TmpDir.Path());
+ int Port = Server->Initialize(0, TmpDir.Path());
REQUIRE(Port != 0);
Server->RegisterService(TestService);
@@ -797,7 +797,7 @@ TEST_CASE("websocket.client")
Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
- int Port = Server->Initialize(7576, TmpDir.Path());
+ int Port = Server->Initialize(0, TmpDir.Path());
REQUIRE(Port != 0);
Server->RegisterService(TestService);
@@ -913,6 +913,75 @@ TEST_CASE("websocket.client")
}
}
+TEST_CASE("websocket.client.unixsocket")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+ std::string SocketPath = (TmpDir.Path() / "ws.sock").string();
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath});
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ Sleep(100);
+
+ SUBCASE("connect, echo, close over unix socket")
+ {
+ TestWsClientHandler Handler;
+ HttpWsClientSettings Settings;
+ Settings.UnixSocketPath = SocketPath;
+
+ HttpWsClient Client("ws://localhost/wstest/ws", Handler, Settings);
+ Client.Connect();
+
+ // Wait for OnWsOpen
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+ CHECK(Client.IsOpen());
+
+ // Send text, expect echo
+ Client.SendText("hello over unix socket");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ CHECK_EQ(Handler.m_MessageCount.load(), 1);
+ CHECK_EQ(Handler.m_LastMessage, "hello over unix socket");
+
+ // Close
+ Client.Close(1000, "done");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ Sleep(50);
+ CHECK_FALSE(Client.IsOpen());
+ }
+}
+
TEST_SUITE_END();
void
diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua
index e8f87b668..9b461662e 100644
--- a/src/zenhttp/xmake.lua
+++ b/src/zenhttp/xmake.lua
@@ -12,6 +12,11 @@ target('zenhttp')
add_packages("http_parser", "json11")
add_options("httpsys")
+ if is_plat("linux", "macosx") then
+ add_packages("openssl3")
+ end
+
if is_plat("linux") then
add_syslinks("dl") -- TODO: is libdl needed?
end
+