diff options
| author | Stefan Boberg <[email protected]> | 2026-03-10 17:27:26 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-10 17:27:26 +0100 |
| commit | d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7 (patch) | |
| tree | 2dfe1e3e0b620043d358e0b7f8bdf8320d985491 /src/zenhttp/clients | |
| parent | changelog entry which was inadvertently omitted from PR merge (diff) | |
| download | zen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.tar.xz zen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.zip | |
HttpClient using libcurl, Unix Sockets for HTTP. HTTPS support (#770)
The main goal of this change is to eliminate the cpr back-end altogether and replace it with the curl implementation. I would expect to drop cpr as soon as we feel happy with the libcurl back-end. That would leave us with a direct dependency on libcurl only, and cpr can be eliminated as a dependency.
### HttpClient Backend Overhaul
- Implemented a new **libcurl-based HttpClient** backend (`httpclientcurl.cpp`, ~2000 lines)
as an alternative to the cpr-based one
- Made HttpClient backend **configurable at runtime** via constructor arguments
and `-httpclient=...` CLI option (for zen, zenserver, and tests)
- Extended HttpClient test suite to cover multipart/content-range scenarios
### Unix Domain Socket Support
- Added Unix domain socket support to **httpasio** (server side)
- Added Unix domain socket support to **HttpClient**
- Added Unix domain socket support to **HttpWsClient** (WebSocket client)
- Templatized `HttpServerConnectionT<SocketType>` and `WsAsioConnectionT<SocketType>`
to handle TCP, Unix, and SSL sockets uniformly via `if constexpr` dispatch
### HTTPS Support
- Added **preliminary HTTPS support to httpasio** (for Mac/Linux via OpenSSL)
- Added **basic HTTPS support for http.sys** (Windows)
- Implemented HTTPS test for httpasio
- Split `InitializeServer` into smaller sub-functions for http.sys
### Other Notable Changes
- Improved **zenhttp-test stability** with dynamic port allocation
- Enhanced port retry logic in http.sys (handles ERROR_ACCESS_DENIED)
- Fatal signal/exception handlers for backtrace generation in tests
- Added `zen bench http` subcommand to exercise network + HTTP client/server communication stack
Diffstat (limited to 'src/zenhttp/clients')
| -rw-r--r-- | src/zenhttp/clients/httpclientcommon.cpp | 57 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.cpp | 135 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.cpp | 1947 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.h | 135 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpwsclient.cpp | 213 |
5 files changed, 2343 insertions, 144 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp index 6f4c67dd0..e4d11547a 100644 --- a/src/zenhttp/clients/httpclientcommon.cpp +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -646,6 +646,63 @@ TEST_CASE("CompositeBufferReadStream") CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data)); } +TEST_CASE("ParseContentRange") +{ + SUBCASE("normal range with total size") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 0-99/500"); + CHECK_EQ(Offset, 0); + CHECK_EQ(Length, 100); + } + + SUBCASE("non-zero offset") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 2638-5111437/44369878"); + CHECK_EQ(Offset, 2638); + CHECK_EQ(Length, 5111437 - 2638 + 1); + } + + SUBCASE("wildcard total size") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 100-199/*"); + CHECK_EQ(Offset, 100); + CHECK_EQ(Length, 100); + } + + SUBCASE("no slash (total size omitted)") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 50-149"); + CHECK_EQ(Offset, 50); + CHECK_EQ(Length, 100); + } + + SUBCASE("malformed input returns zeros") + { + auto [Offset1, Length1] = detail::ParseContentRange("not-bytes 0-99/500"); + CHECK_EQ(Offset1, 0); + CHECK_EQ(Length1, 0); + + auto [Offset2, Length2] = detail::ParseContentRange("bytes abc-def/500"); + CHECK_EQ(Offset2, 0); + CHECK_EQ(Length2, 0); + + auto [Offset3, Length3] = detail::ParseContentRange(""); + CHECK_EQ(Offset3, 0); + CHECK_EQ(Length3, 0); + + auto [Offset4, Length4] = detail::ParseContentRange("bytes 100/500"); + CHECK_EQ(Offset4, 0); + CHECK_EQ(Length4, 0); + } + + SUBCASE("single byte range") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 42-42/1000"); + CHECK_EQ(Offset, 42); + CHECK_EQ(Length, 1); + } +} + TEST_CASE("MultipartBoundaryParser") { uint64_t Range1Offset = 2638; diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index 14e40b02a..f3082e0a2 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -14,6 +14,11 @@ #include <zenhttp/packageformat.h> #include <algorithm> +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/ssl_options.h> +#include <cpr/unix_socket.h> +ZEN_THIRD_PARTY_INCLUDES_END + namespace zen { HttpClientBase* @@ -24,84 +29,42 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; -bool -HttpClient::ErrorContext::IsConnectionError() const +////////////////////////////////////////////////////////////////////////// + +static HttpClientErrorCode +MapCprError(cpr::ErrorCode Code) { - switch (static_cast<cpr::ErrorCode>(ErrorCode)) + switch (Code) { + case cpr::ErrorCode::OK: + return HttpClientErrorCode::kOK; case cpr::ErrorCode::CONNECTION_FAILURE: - case cpr::ErrorCode::OPERATION_TIMEDOUT: + return HttpClientErrorCode::kConnectionFailure; case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: + return HttpClientErrorCode::kHostResolutionFailure; case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: - return true; + return HttpClientErrorCode::kProxyResolutionFailure; + case cpr::ErrorCode::INTERNAL_ERROR: + return HttpClientErrorCode::kInternalError; + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + return HttpClientErrorCode::kNetworkSendFailure; + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case cpr::ErrorCode::SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: + return HttpClientErrorCode::kSSLCertificateError; + case cpr::ErrorCode::SSL_CACERT_ERROR: + return HttpClientErrorCode::kSSLCACertError; + case cpr::ErrorCode::GENERIC_SSL_ERROR: + return HttpClientErrorCode::kGenericSSLError; + case cpr::ErrorCode::REQUEST_CANCELLED: + return HttpClientErrorCode::kRequestCancelled; default: - return false; - } -} - -// If we want to support different HTTP client implementations then we'll need to make this more abstract - -HttpClientError::ResponseClass -HttpClientError::GetResponseClass() const -{ - if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK) - { - switch ((cpr::ErrorCode)m_Error) - { - case cpr::ErrorCode::CONNECTION_FAILURE: - return ResponseClass::kHttpCantConnectError; - case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: - case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: - return ResponseClass::kHttpNoHost; - case cpr::ErrorCode::INTERNAL_ERROR: - case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: - case cpr::ErrorCode::NETWORK_SEND_FAILURE: - case cpr::ErrorCode::OPERATION_TIMEDOUT: - return ResponseClass::kHttpTimeout; - case cpr::ErrorCode::SSL_CONNECT_ERROR: - case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: - case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: - case cpr::ErrorCode::SSL_CACERT_ERROR: - case cpr::ErrorCode::GENERIC_SSL_ERROR: - return ResponseClass::kHttpSLLError; - default: - return ResponseClass::kHttpOtherClientError; - } - } - else if (IsHttpSuccessCode(m_ResponseCode)) - { - return ResponseClass::kSuccess; - } - else - { - switch (m_ResponseCode) - { - case HttpResponseCode::Unauthorized: - return ResponseClass::kHttpUnauthorized; - case HttpResponseCode::NotFound: - return ResponseClass::kHttpNotFound; - case HttpResponseCode::Forbidden: - return ResponseClass::kHttpForbidden; - case HttpResponseCode::Conflict: - return ResponseClass::kHttpConflict; - case HttpResponseCode::InternalServerError: - return ResponseClass::kHttpInternalServerError; - case HttpResponseCode::ServiceUnavailable: - return ResponseClass::kHttpServiceUnavailable; - case HttpResponseCode::BadGateway: - return ResponseClass::kHttpBadGateway; - case HttpResponseCode::GatewayTimeout: - return ResponseClass::kHttpGatewayTimeout; - default: - if (m_ResponseCode >= HttpResponseCode::InternalServerError) - { - return ResponseClass::kHttpOtherServerError; - } - else - { - return ResponseClass::kHttpOtherClientError; - } - } + return HttpClientErrorCode::kOtherError; } } @@ -257,8 +220,8 @@ CprHttpClient::CommonResponse(std::string_view SessionId, .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), .ElapsedSeconds = HttpResponse.elapsed, - .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code), - .ErrorMessage = HttpResponse.error.message}}; + .Error = + HttpClient::ErrorContext{.ErrorCode = MapCprError(HttpResponse.error.code), .ErrorMessage = HttpResponse.error.message}}; } if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload)) @@ -526,6 +489,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, { CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}}); } + if (ConnectionSettings.ForbidReuseConnection) + { + CprSession->UpdateHeader({{"Connection", "close"}}); + } if (AccessToken) { CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); @@ -544,6 +511,26 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, CprSession->SetParameters({}); } + if (!ConnectionSettings.UnixSocketPath.empty()) + { + CprSession->SetUnixSocket(cpr::UnixSocket(ConnectionSettings.UnixSocketPath)); + } + + if (ConnectionSettings.InsecureSsl || !ConnectionSettings.CaBundlePath.empty()) + { + cpr::SslOptions SslOpts; + if (ConnectionSettings.InsecureSsl) + { + SslOpts.SetOption(cpr::ssl::VerifyHost{false}); + SslOpts.SetOption(cpr::ssl::VerifyPeer{false}); + } + if (!ConnectionSettings.CaBundlePath.empty()) + { + SslOpts.SetOption(cpr::ssl::CaInfo{ConnectionSettings.CaBundlePath}); + } + CprSession->SetSslOptions(SslOpts); + } + ExtendableStringBuilder<128> UrlBuffer; UrlBuffer << BaseUrl << ResourcePath; CprSession->SetUrl(UrlBuffer.c_str()); diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp new file mode 100644 index 000000000..3cb749018 --- /dev/null +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -0,0 +1,1947 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcurl.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/compress.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zenhttp/packageformat.h> +#include <algorithm> + +namespace zen { + +HttpClientBase* +CreateCurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction) +{ + return new CurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); +} + +static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0}; + +////////////////////////////////////////////////////////////////////////// + +static HttpClientErrorCode +MapCurlError(CURLcode Code) +{ + switch (Code) + { + case CURLE_OK: + return HttpClientErrorCode::kOK; + case CURLE_COULDNT_CONNECT: + return HttpClientErrorCode::kConnectionFailure; + case CURLE_COULDNT_RESOLVE_HOST: + return HttpClientErrorCode::kHostResolutionFailure; + case CURLE_COULDNT_RESOLVE_PROXY: + return HttpClientErrorCode::kProxyResolutionFailure; + case CURLE_RECV_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case CURLE_SEND_ERROR: + return HttpClientErrorCode::kNetworkSendFailure; + case CURLE_OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case CURLE_SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case CURLE_SSL_CERTPROBLEM: + return HttpClientErrorCode::kSSLCertificateError; + case CURLE_PEER_FAILED_VERIFICATION: + return HttpClientErrorCode::kSSLCACertError; + case CURLE_SSL_CIPHER: + case CURLE_SSL_ENGINE_NOTFOUND: + case CURLE_SSL_ENGINE_SETFAILED: + return HttpClientErrorCode::kGenericSSLError; + case CURLE_ABORTED_BY_CALLBACK: + return HttpClientErrorCode::kRequestCancelled; + default: + return HttpClientErrorCode::kOtherError; + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Curl callback helpers + +struct WriteCallbackData +{ + std::string* Body = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<WriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; // Signal abort to curl + } + + Data->Body->append(Ptr, TotalBytes); + return TotalBytes; +} + +struct HeaderCallbackData +{ + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; +}; + +static size_t +CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<HeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + std::string_view Line(Buffer, TotalBytes); + + // Trim trailing \r\n + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return TotalBytes; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos != std::string_view::npos) + { + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + + // Trim whitespace + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; +} + +struct ReadCallbackData +{ + const uint8_t* DataPtr = nullptr; + size_t DataSize = 0; + size_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<ReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->DataSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +struct StreamReadCallbackData +{ + detail::CompositeBufferReadStream* Reader = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlStreamReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<StreamReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + return Data->Reader->Read(Buffer, MaxRead); +} + +struct FileReadCallbackData +{ + detail::BufferedReadFileStream* Buffer = nullptr; + uint64_t TotalSize = 0; + uint64_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlFileReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<FileReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->TotalSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + Data->Buffer->Read(Buffer, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +static int +CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, void* UserPtr) +{ + ZEN_UNUSED(Handle); + LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr); + auto Log = [&]() -> LoggerRef { return LogRef; }; + + std::string_view DataView(Data, Size); + + // Trim trailing newlines + while (!DataView.empty() && (DataView.back() == '\r' || DataView.back() == '\n')) + { + DataView.remove_suffix(1); + } + + switch (Type) + { + case CURLINFO_TEXT: + if (DataView.find("need more data"sv) == std::string_view::npos) + { + ZEN_INFO("TEXT: {}", DataView); + } + break; + case CURLINFO_HEADER_IN: + ZEN_INFO("HIN : {}", DataView); + break; + case CURLINFO_HEADER_OUT: + if (auto TokenPos = DataView.find("Authorization: Bearer "sv); TokenPos != std::string_view::npos) + { + std::string Copy(DataView); + auto BearerStart = TokenPos + 22; + auto BearerEnd = Copy.find_first_of("\r\n", BearerStart); + if (BearerEnd == std::string::npos) + { + BearerEnd = Copy.length(); + } + Copy.replace(Copy.begin() + BearerStart, Copy.begin() + BearerEnd, fmt::format("[{} char token]", BearerEnd - BearerStart)); + ZEN_INFO("HOUT: {}", Copy); + } + else + { + ZEN_INFO("HOUT: {}", DataView); + } + break; + default: + break; + } + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +static std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +static curl_slist* +BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, + std::string_view SessionId, + const std::optional<HttpClientAccessToken>& AccessToken, + const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) +{ + curl_slist* Headers = nullptr; + + for (const auto& [Key, Value] : *AdditionalHeader) + { + std::string HeaderLine = fmt::format("{}: {}", Key, Value); + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + if (!SessionId.empty()) + { + std::string SessionHeader = fmt::format("UE-Session: {}", SessionId); + Headers = curl_slist_append(Headers, SessionHeader.c_str()); + } + + if (AccessToken) + { + std::string AuthHeader = fmt::format("Authorization: {}", AccessToken->Value); + Headers = curl_slist_append(Headers, AuthHeader.c_str()); + } + + for (const auto& [Key, Value] : ExtraHeaders) + { + std::string HeaderLine = fmt::format("{}: {}", Key, Value); + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + return Headers; +} + +static std::string +BuildUrlWithParameters(std::string_view BaseUrl, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters) +{ + std::string Url; + Url.reserve(BaseUrl.size() + ResourcePath.size() + 64); + Url.append(BaseUrl); + Url.append(ResourcePath); + + if (!Parameters->empty()) + { + char Separator = '?'; + for (const auto& [Key, Value] : *Parameters) + { + char* EncodedKey = curl_easy_escape(nullptr, Key.c_str(), static_cast<int>(Key.size())); + char* EncodedValue = curl_easy_escape(nullptr, Value.c_str(), static_cast<int>(Value.size())); + Url += Separator; + Url += EncodedKey; + Url += '='; + Url += EncodedValue; + curl_free(EncodedKey); + curl_free(EncodedValue); + Separator = '&'; + } + } + + return Url; +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::CurlHttpClient(std::string_view BaseUri, + const HttpClientSettings& ConnectionSettings, + std::function<bool()>&& CheckIfAbortFunction) +: HttpClientBase(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)) +{ +} + +CurlHttpClient::~CurlHttpClient() +{ + ZEN_TRACE_CPU("CurlHttpClient::~CurlHttpClient"); + m_SessionLock.WithExclusiveLock([&] { + for (auto* Handle : m_Sessions) + { + curl_easy_cleanup(Handle); + } + m_Sessions.clear(); + }); +} + +CurlHttpClient::CurlResult +CurlHttpClient::Session::Perform() +{ + CurlResult Result; + + char ErrorBuffer[CURL_ERROR_SIZE] = {}; + curl_easy_setopt(Handle, CURLOPT_ERRORBUFFER, ErrorBuffer); + + Result.ErrorCode = curl_easy_perform(Handle); + + if (Result.ErrorCode != CURLE_OK) + { + Result.ErrorMessage = ErrorBuffer[0] ? std::string(ErrorBuffer) : curl_easy_strerror(Result.ErrorCode); + } + + curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &Result.StatusCode); + + double Elapsed = 0; + curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed); + Result.ElapsedSeconds = Elapsed; + + curl_off_t UpBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes); + Result.UploadedBytes = static_cast<int64_t>(UpBytes); + + curl_off_t DownBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); + Result.DownloadedBytes = static_cast<int64_t>(DownBytes); + + return Result; +} + +bool +CurlHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const +{ + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes; + return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end(); +} + +HttpClient::Response +CurlHttpClient::ResponseWithPayload(std::string_view SessionId, + CurlResult&& Result, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) +{ + IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, Result.Body.data(), Result.Body.size()); + + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "Content-Type") + { + const HttpContentType ContentType = ParseContentType(Value); + ResponseBuffer.SetContentType(ContentType); + break; + } + } + + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + { + if (ShouldLogErrorCode(WorkResponseCode)) + { + ZEN_WARN("HttpClient request failed (session: {}): status={}, url={}", + SessionId, + static_cast<int>(WorkResponseCode), + m_BaseUri); + } + } + + std::sort(BoundaryPositions.begin(), + BoundaryPositions.end(), + [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) { + return Lhs.RangeOffset < Rhs.RangeOffset; + }); + + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Result.Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .ResponsePayload = std::move(ResponseBuffer), + .Header = std::move(HeaderMap), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Ranges = std::move(BoundaryPositions)}; +} + +HttpClient::Response +CurlHttpClient::CommonResponse(std::string_view SessionId, + CurlResult&& Result, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) +{ + const HttpResponseCode WorkResponseCode = HttpResponseCode(Result.StatusCode); + if (Result.ErrorCode != CURLE_OK) + { + const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); + if (!Quiet) + { + if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT && + Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK) + { + ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'", + SessionId, + static_cast<int>(Result.ErrorCode), + Result.ErrorMessage); + } + } + + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Result.Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + + return HttpClient::Response{ + .StatusCode = WorkResponseCode, + .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Result.Body.data(), Result.Body.size()), + .Header = std::move(HeaderMap), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(Result.ErrorCode), .ErrorMessage = Result.ErrorMessage}}; + } + + if (WorkResponseCode == HttpResponseCode::NoContent || (Result.Body.empty() && !Payload)) + { + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Result.Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .Header = std::move(HeaderMap), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds}; + } + else + { + return ResponseWithPayload(SessionId, std::move(Result), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions)); + } +} + +bool +CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + ZEN_TRACE_CPU("ValidatePayload"); + + IoBuffer ResponseBuffer = (Result.Body.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() + : IoBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size()); + + // Find Content-Length in headers + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "Content-Length") + { + std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(Value); + if (!ExpectedContentSize.has_value()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", Value); + return false; + } + if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), Value); + return false; + } + break; + } + } + + if (Result.StatusCode == static_cast<long>(HttpResponseCode::PartialContent)) + { + return true; + } + + // Check X-Jupiter-IoHash + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "X-Jupiter-IoHash") + { + IoHash ExpectedPayloadHash; + if (IoHash::TryParse(Value, ExpectedPayloadHash)) + { + IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); + if (PayloadHash != ExpectedPayloadHash) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", + PayloadHash.ToHexString(), + ExpectedPayloadHash.ToHexString()); + return false; + } + } + break; + } + } + + // Validate content-type specific payload + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "Content-Type") + { + if (Value == "application/x-ue-comp") + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, + RawHash, + RawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = "Compressed binary failed validation"; + return false; + } + } + if (Value == "application/x-ue-cb") + { + if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); + Error == CbValidateError::None) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error)); + return false; + } + } + break; + } + } + + return true; +} + +bool +CurlHttpClient::ShouldRetry(const CurlResult& Result) +{ + switch (Result.ErrorCode) + { + case CURLE_OK: + break; + case CURLE_RECV_ERROR: + case CURLE_SEND_ERROR: + case CURLE_OPERATION_TIMEDOUT: + return true; + default: + return false; + } + switch (static_cast<HttpResponseCode>(Result.StatusCode)) + { + case HttpResponseCode::RequestTimeout: + case HttpResponseCode::TooManyRequests: + case HttpResponseCode::InternalServerError: + case HttpResponseCode::BadGateway: + case HttpResponseCode::ServiceUnavailable: + case HttpResponseCode::GatewayTimeout: + return true; + default: + return false; + } +} + +CurlHttpClient::CurlResult +CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::function<bool(CurlResult&)>&& Validate) +{ + uint8_t Attempt = 0; + CurlResult Result = Func(); + while (Attempt < m_ConnectionSettings.RetryCount) + { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return Result; + } + if (!ShouldRetry(Result)) + { + if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) + { + break; + } + if (Validate(Result)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) + { + ZEN_INFO("{} Attempt {}/{}", + CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + Result = Func(); + } + return Result; +} + +CurlHttpClient::CurlResult +CurlHttpClient::DoWithRetry(std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + uint8_t Attempt = 0; + CurlResult Result = Func(); + while (Attempt < m_ConnectionSettings.RetryCount) + { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return Result; + } + if (!ShouldRetry(Result)) + { + if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) + { + break; + } + if (ValidatePayload(Result, PayloadFile)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) + { + ZEN_INFO("{} Attempt {}/{}", + CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + Result = Func(); + } + return Result; +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::Session +CurlHttpClient::AllocSession(std::string_view BaseUrl, + std::string_view ResourcePath, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken) +{ + ZEN_UNUSED(AccessToken, SessionId, AdditionalHeader); + ZEN_TRACE_CPU("CurlHttpClient::AllocSession"); + CURL* Handle = nullptr; + m_SessionLock.WithExclusiveLock([&] { + if (!m_Sessions.empty()) + { + Handle = m_Sessions.back(); + m_Sessions.pop_back(); + } + }); + + if (Handle == nullptr) + { + Handle = curl_easy_init(); + } + else + { + curl_easy_reset(Handle); + } + + // Unix domain socket + if (!ConnectionSettings.UnixSocketPath.empty()) + { + curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, ConnectionSettings.UnixSocketPath.c_str()); + } + + // Build URL with parameters + std::string Url = BuildUrlWithParameters(BaseUrl, ResourcePath, Parameters); + curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); + + // Timeouts + if (ConnectionSettings.ConnectTimeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(ConnectionSettings.ConnectTimeout.count())); + } + if (ConnectionSettings.Timeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(ConnectionSettings.Timeout.count())); + } + + // HTTP/2 + if (ConnectionSettings.AssumeHttp2) + { + curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); + } + + // Verbose/debug + if (ConnectionSettings.Verbose) + { + curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); + curl_easy_setopt(Handle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback); + curl_easy_setopt(Handle, CURLOPT_DEBUGDATA, &m_Log); + } + + // SSL options + if (ConnectionSettings.InsecureSsl) + { + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L); + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L); + } + if (!ConnectionSettings.CaBundlePath.empty()) + { + curl_easy_setopt(Handle, CURLOPT_CAINFO, ConnectionSettings.CaBundlePath.c_str()); + } + + // Disable signal handling for thread safety + curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); + + if (ConnectionSettings.ForbidReuseConnection) + { + curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); + } + + // Note: Headers are NOT set here. Each method builds its own header list + // (potentially adding method-specific headers like Content-Type) and is + // responsible for freeing it with curl_slist_free_all. + + return Session(this, Handle); +} + +void +CurlHttpClient::ReleaseSession(CURL* Handle) +{ + ZEN_TRACE_CPU("CurlHttpClient::ReleaseSession"); + + // Free any header list that was set + // curl_easy_reset will be called on next AllocSession, which cleans up the handle state. + // We just push the handle back to the pool. + m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(Handle); }); +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::Response +CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::TransactPackage"); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++CurlHttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (Attachments.empty() == false) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + Writer.AddHash(Attachment.GetHash()); + } + + Writer.EndArray(); + + BinaryWriter MemWriter; + Writer.Save(MemWriter); + + std::vector<std::pair<std::string, std::string>> OfferExtraHeaders; + OfferExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackageOffer)); + OfferExtraHeaders.emplace_back("UE-Request", RequestIdString); + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(MemWriter.Data())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(MemWriter.Size())); + + std::string FilterBody; + WriteCallbackData WriteData{.Body = &FilterBody}; + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + + CurlResult Result = Sess.Perform(); + + curl_slist_free_all(HeaderList); + + if (Result.ErrorCode == CURLE_OK && Result.StatusCode == 200) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterBody.data(), FilterBody.size()); + CbValidateError ValidationError = CbValidateError::None; + if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); + ValidationError == CbValidateError::None) + { + for (CbFieldView& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + std::vector<std::pair<std::string, std::string>> PkgExtraHeaders; + PkgExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackage)); + PkgExtraHeaders.emplace_back("UE-Request", RequestIdString); + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(FlatMessage.GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(FlatMessage.GetSize())); + + std::string PkgBody; + WriteCallbackData WriteData{.Body = &PkgBody}; + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + + CurlResult Result = Sess.Perform(); + + curl_slist_free_all(HeaderList); + + if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) + { + return {.StatusCode = HttpResponseCode(Result.StatusCode)}; + } + + IoBuffer ResponseBuffer(IoBuffer::Clone, PkgBody.data(), PkgBody.size()); + + return {.StatusCode = HttpResponseCode(Result.StatusCode), .ResponsePayload = ResponseBuffer}; +} + +////////////////////////////////////////////////////////////////////////// +// +// Standard HTTP verbs +// + +CurlHttpClient::Response +CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeaderWithContentLength, Parameters, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::Get"); + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }, + [this](CurlResult& Result) { + std::unique_ptr<detail::TempPayloadFile> NoTempFile; + return ValidatePayload(Result, NoTempFile); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Head"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_NOBODY, 1L); + + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Delete"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_CUSTOMREQUEST, "DELETE"); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostNoPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE, 0L); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostWithPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + // Rebuild headers with content type + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + + FileReadCallbackData ReadData{.Buffer = &Buffer, + .TotalSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + } + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostObjectPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetBuffer().GetSize())); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Post"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + + StreamReadCallbackData ReadData{.Reader = &Reader, + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + + FileReadCallbackData ReadData{.Buffer = &Buffer, + .TotalSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + } + + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + + StreamReadCallbackData ReadData{.Reader = &Reader, + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Download"); + + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + + HttpContentType ContentType = HttpContentType::kUnknownContentType; + detail::MultipartBoundaryParser BoundaryParser; + bool IsMultiRangeResponse = false; + + CurlResult Result = DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* DlHeaders = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, DlHeaders); + curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); + + // Reset state from any previous attempt + PayloadString.clear(); + PayloadFile.reset(); + BoundaryParser.Boundaries.clear(); + ContentType = HttpContentType::kUnknownContentType; + IsMultiRangeResponse = false; + + // Track requested content length from Range header (sum all ranges) + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + std::string_view RangeValue(RangeIt->second); + size_t RangeStartPos = RangeValue.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') + { + RangeStartPos++; + } + RequestedContentLength = 0; + + while (RangeStartPos < RangeValue.length()) + { + size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); + if (RangeEnd == std::string::npos) + { + RangeEnd = RangeValue.length(); + } + + std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); + size_t RangeSplitPos = RangeString.find('-'); + if (RangeSplitPos != std::string::npos) + { + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1; + } + } + RangeStartPos = RangeEnd; + while (RangeStartPos != RangeValue.length() && + (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' ')) + { + RangeStartPos++; + } + } + } + } + } + + // Header callback that detects Content-Length and switches to file-backed storage when needed + struct DownloadHeaderCallbackData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + uint64_t MaxInMemorySize = 0; + LoggerRef Log; + detail::MultipartBoundaryParser* BoundaryParser = nullptr; + bool* IsMultiRange = nullptr; + HttpContentType* ContentTypeOut = nullptr; + }; + + DownloadHeaderCallbackData DlHdrData; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + DlHdrData.Headers = &ResponseHeaders; + DlHdrData.PayloadFile = &PayloadFile; + DlHdrData.PayloadString = &PayloadString; + DlHdrData.TempFolderPath = &TempFolderPath; + DlHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize; + DlHdrData.Log = m_Log; + DlHdrData.BoundaryParser = &BoundaryParser; + DlHdrData.IsMultiRange = &IsMultiRangeResponse; + DlHdrData.ContentTypeOut = &ContentType; + + auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<DownloadHeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + std::string_view Line(Buffer, TotalBytes); + + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return TotalBytes; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos != std::string_view::npos) + { + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + if (Key == "Content-Length"sv) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Value); + if (ContentLength.has_value()) + { + if (ContentLength.value() > Data->MaxInMemorySize) + { + *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + Data->PayloadFile->reset(); + } + } + else + { + Data->PayloadString->reserve(ContentLength.value()); + } + } + } + else if (Key == "Content-Type"sv) + { + *Data->IsMultiRange = Data->BoundaryParser->Init(Value); + if (!*Data->IsMultiRange) + { + *Data->ContentTypeOut = ParseContentType(Value); + } + } + else if (Key == "Content-Range"sv) + { + if (!*Data->IsMultiRange) + { + std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Value); + if (Range.second != 0) + { + Data->BoundaryParser->Boundaries.push_back( + HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0, + .RangeOffset = Range.first, + .RangeLength = Range.second, + .ContentType = *Data->ContentTypeOut}); + } + } + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb)); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &DlHdrData); + + // Write callback that directs data to file or string + struct DownloadWriteCallbackData + { + std::string* PayloadString = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + LoggerRef Log; + detail::MultipartBoundaryParser* BoundaryParser = nullptr; + bool* IsMultiRange = nullptr; + }; + + DownloadWriteCallbackData DlWriteData; + DlWriteData.PayloadString = &PayloadString; + DlWriteData.PayloadFile = &PayloadFile; + DlWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr; + DlWriteData.TempFolderPath = &TempFolderPath; + DlWriteData.Log = m_Log; + DlWriteData.BoundaryParser = &BoundaryParser; + DlWriteData.IsMultiRange = &IsMultiRangeResponse; + + auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<DownloadWriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; + } + + if (*Data->IsMultiRange) + { + Data->BoundaryParser->ParseInput(std::string_view(Ptr, TotalBytes)); + } + + if (*Data->PayloadFile) + { + std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + return 0; + } + } + else + { + Data->PayloadString->append(Ptr, TotalBytes); + } + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &DlWriteData); + + CurlResult Res = Sess.Perform(); + Res.Headers = std::move(ResponseHeaders); + + // Handle resume logic + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const CurlResult& R) -> bool { + for (const auto& [K, V] : R.Headers) + { + if (K == "Content-Range") + { + return true; + } + if (K == "Accept-Ranges" && V == "bytes") + { + return true; + } + } + return false; + }; + + auto ShouldResumeCheck = [&SupportsRanges, &IsMultiRangeResponse](const CurlResult& R) -> bool { + if (IsMultiRangeResponse) + { + return false; + } + if (ShouldRetry(R)) + { + return SupportsRanges(R); + } + return false; + }; + + if (ShouldResumeCheck(Res)) + { + // Find Content-Length + std::string ContentLengthValue; + for (const auto& [K, V] : Res.Headers) + { + if (K == "Content-Length") + { + ContentLengthValue = V; + break; + } + } + + if (!ContentLengthValue.empty()) + { + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(ContentLengthValue); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + break; // No progress, abort + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session ResumeSess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + CURL* ResumeH = ResumeSess.Get(); + + curl_slist* ResumeHdrList = BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken()); + curl_easy_setopt(ResumeH, CURLOPT_HTTPHEADER, ResumeHdrList); + curl_easy_setopt(ResumeH, CURLOPT_HTTPGET, 1L); + + std::vector<std::pair<std::string, std::string>> ResumeHeaders; + + struct ResumeHeaderCbData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + }; + + ResumeHeaderCbData ResumeHdrData; + ResumeHdrData.Headers = &ResumeHeaders; + ResumeHdrData.PayloadFile = &PayloadFile; + ResumeHdrData.PayloadString = &PayloadString; + + auto ResumeHeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<ResumeHeaderCbData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + std::string_view Line(Buffer, TotalBytes); + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return TotalBytes; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos != std::string_view::npos) + { + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + if (Key == "Content-Range"sv) + { + if (Value.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Value.find('-', 6); + if (RangeStartEnd != std::string_view::npos) + { + const std::optional<uint64_t> Start = + ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = *Data->PayloadFile ? (*Data->PayloadFile)->GetSize() + : Data->PayloadString->length(); + if (Start.value() == DownloadedSize) + { + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (*Data->PayloadFile) + { + (*Data->PayloadFile)->ResetWritePos(Start.value()); + } + else + { + *Data->PayloadString = Data->PayloadString->substr(0, Start.value()); + } + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + } + } + } + return 0; + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; + }; + + curl_easy_setopt(ResumeH, + CURLOPT_HEADERFUNCTION, + static_cast<size_t (*)(char*, size_t, size_t, void*)>(ResumeHeaderCb)); + curl_easy_setopt(ResumeH, CURLOPT_HEADERDATA, &ResumeHdrData); + curl_easy_setopt(ResumeH, + CURLOPT_WRITEFUNCTION, + static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(ResumeH, CURLOPT_WRITEDATA, &DlWriteData); + + Res = ResumeSess.Perform(); + Res.Headers = std::move(ResumeHeaders); + + curl_slist_free_all(ResumeHdrList); + } while (ShouldResumeCheck(Res)); + } + } + } + + if (!PayloadString.empty()) + { + Res.Body = std::move(PayloadString); + } + + curl_slist_free_all(DlHeaders); + + return Res; + }, + PayloadFile); + + return CommonResponse(m_SessionId, + std::move(Result), + PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, + std::move(BoundaryParser.Boundaries)); +} + +} // namespace zen diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h new file mode 100644 index 000000000..2a49ff308 --- /dev/null +++ b/src/zenhttp/clients/httpclientcurl.h @@ -0,0 +1,135 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "httpclientcommon.h" + +#include <zencore/logging.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <curl/curl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CurlHttpClient : public HttpClientBase +{ +public: + CurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction); + ~CurlHttpClient(); + + // HttpClientBase + + [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response Download(std::string_view Url, + const std::filesystem::path& TempFolderPath, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response TransactPackage(std::string_view Url, + CbPackage Package, + const KeyValueMap& AdditionalHeader = {}) override; + +private: + struct CurlResult + { + long StatusCode = 0; + std::string Body; + std::vector<std::pair<std::string, std::string>> Headers; + double ElapsedSeconds = 0; + int64_t UploadedBytes = 0; + int64_t DownloadedBytes = 0; + CURLcode ErrorCode = CURLE_OK; + std::string ErrorMessage; + }; + + struct Session + { + Session(CurlHttpClient* InOuter, CURL* InHandle) : Outer(InOuter), Handle(InHandle) {} + ~Session() { Outer->ReleaseSession(Handle); } + + CURL* Get() const { return Handle; } + + CurlResult Perform(); + + LoggerRef Log() { return Outer->Log(); } + + private: + CurlHttpClient* Outer; + CURL* Handle; + + Session(Session&&) = delete; + Session& operator=(Session&&) = delete; + }; + + Session AllocSession(std::string_view BaseUrl, + std::string_view Url, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken); + + RwLock m_SessionLock; + std::vector<CURL*> m_Sessions; + + void ReleaseSession(CURL* Handle); + + struct RetryResult + { + CurlResult Result; + }; + + CurlResult DoWithRetry(std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile); + CurlResult DoWithRetry( + std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::function<bool(CurlResult&)>&& Validate = [](CurlResult&) { return true; }); + + bool ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile); + + static bool ShouldRetry(const CurlResult& Result); + + bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const; + + HttpClient::Response CommonResponse(std::string_view SessionId, + CurlResult&& Result, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {}); + + HttpClient::Response ResponseWithPayload(std::string_view SessionId, + CurlResult&& Result, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions); +}; + +} // namespace zen diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 9497dadb8..792848a6b 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -10,6 +10,9 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif ZEN_THIRD_PARTY_INCLUDES_END #include <deque> @@ -47,11 +50,7 @@ struct HttpWsClient::Impl m_WorkGuard.reset(); // Close the socket to cancel pending async ops - if (m_Socket) - { - asio::error_code Ec; - m_Socket->close(Ec); - } + CloseSocket(); if (m_IoThread.joinable()) { @@ -59,6 +58,35 @@ struct HttpWsClient::Impl } } + void CloseSocket() + { + asio::error_code Ec; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_UnixSocket) + { + m_UnixSocket->close(Ec); + return; + } +#endif + if (m_TcpSocket) + { + m_TcpSocket->close(Ec); + } + } + + template<typename Fn> + void WithSocket(Fn&& Func) + { +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_UnixSocket) + { + Func(*m_UnixSocket); + return; + } +#endif + Func(*m_TcpSocket); + } + void ParseUrl(std::string_view Url) { // Expected format: ws://host:port/path @@ -101,9 +129,47 @@ struct HttpWsClient::Impl m_IoThread = std::thread([this] { m_IoContext.run(); }); } +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (!m_Settings.UnixSocketPath.empty()) + { + asio::post(m_IoContext, [this] { DoConnectUnix(); }); + return; + } +#endif + asio::post(m_IoContext, [this] { DoResolve(); }); } +#if defined(ASIO_HAS_LOCAL_SOCKETS) + void DoConnectUnix() + { + m_UnixSocket = std::make_unique<asio::local::stream_protocol::socket>(m_IoContext); + + // Start connect timeout timer + m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout); + m_Timer->async_wait([this](const asio::error_code& Ec) { + if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect timeout for {}", m_Settings.UnixSocketPath); + CloseSocket(); + } + }); + + asio::local::stream_protocol::endpoint Endpoint(m_Settings.UnixSocketPath); + m_UnixSocket->async_connect(Endpoint, [this](const asio::error_code& Ec) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect failed for {}: {}", m_Settings.UnixSocketPath, Ec.message()); + m_Handler.OnWsClose(1006, "connect failed"); + return; + } + + DoHandshake(); + }); + } +#endif + void DoResolve() { m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext); @@ -122,7 +188,7 @@ struct HttpWsClient::Impl void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints) { - m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoContext); + m_TcpSocket = std::make_unique<asio::ip::tcp::socket>(m_IoContext); // Start connect timeout timer m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout); @@ -130,15 +196,11 @@ struct HttpWsClient::Impl if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) { ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port); - if (m_Socket) - { - asio::error_code CloseEc; - m_Socket->close(CloseEc); - } + CloseSocket(); } }); - asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) { + asio::async_connect(*m_TcpSocket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) { if (Ec) { m_Timer->cancel(); @@ -194,64 +256,68 @@ struct HttpWsClient::Impl m_HandshakeBuffer = std::make_shared<std::string>(ReqStr); - asio::async_write(*m_Socket, - asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()), - [this](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - m_Timer->cancel(); - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message()); - m_Handler.OnWsClose(1006, "handshake write failed"); - return; - } - - DoReadHandshakeResponse(); - }); + WithSocket([this](auto& Socket) { + asio::async_write(Socket, + asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()), + [this](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake write failed"); + return; + } + + DoReadHandshakeResponse(); + }); + }); } void DoReadHandshakeResponse() { - asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) { - m_Timer->cancel(); + WithSocket([this](auto& Socket) { + asio::async_read_until(Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) { + m_Timer->cancel(); - if (Ec) - { - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message()); - m_Handler.OnWsClose(1006, "handshake read failed"); - return; - } + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake read failed"); + return; + } - // Parse the response - const auto& Data = m_ReadBuffer.data(); - std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); + // Parse the response + const auto& Data = m_ReadBuffer.data(); + std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); - // Consume the headers from the read buffer (any extra data stays for frame parsing) - auto HeaderEnd = Response.find("\r\n\r\n"); - if (HeaderEnd != std::string::npos) - { - m_ReadBuffer.consume(HeaderEnd + 4); - } + // Consume the headers from the read buffer (any extra data stays for frame parsing) + auto HeaderEnd = Response.find("\r\n\r\n"); + if (HeaderEnd != std::string::npos) + { + m_ReadBuffer.consume(HeaderEnd + 4); + } - // Validate 101 response - if (Response.find("101") == std::string::npos) - { - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); - m_Handler.OnWsClose(1006, "handshake rejected"); - return; - } + // Validate 101 response + if (Response.find("101") == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); + m_Handler.OnWsClose(1006, "handshake rejected"); + return; + } - // Validate Sec-WebSocket-Accept - std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); - if (Response.find(ExpectedAccept) == std::string::npos) - { - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); - m_Handler.OnWsClose(1006, "invalid accept key"); - return; - } + // Validate Sec-WebSocket-Accept + std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); + if (Response.find(ExpectedAccept) == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); + m_Handler.OnWsClose(1006, "invalid accept key"); + return; + } - m_IsOpen.store(true); - m_Handler.OnWsOpen(); - EnqueueRead(); + m_IsOpen.store(true); + m_Handler.OnWsOpen(); + EnqueueRead(); + }); }); } @@ -267,8 +333,10 @@ struct HttpWsClient::Impl return; } - asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) { - OnDataReceived(Ec); + WithSocket([this](auto& Socket) { + asio::async_read(Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) { + OnDataReceived(Ec); + }); }); } @@ -414,9 +482,11 @@ struct HttpWsClient::Impl auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame)); - asio::async_write(*m_Socket, - asio::buffer(OwnedFrame->data(), OwnedFrame->size()), - [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); }); + WithSocket([this, OwnedFrame](auto& Socket) { + asio::async_write(Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); }); + }); } void OnWriteComplete(const asio::error_code& Ec) @@ -501,11 +571,14 @@ struct HttpWsClient::Impl // Connection state std::unique_ptr<asio::ip::tcp::resolver> m_Resolver; - std::unique_ptr<asio::ip::tcp::socket> m_Socket; - std::unique_ptr<asio::steady_timer> m_Timer; - asio::streambuf m_ReadBuffer; - std::string m_WebSocketKey; - std::shared_ptr<std::string> m_HandshakeBuffer; + std::unique_ptr<asio::ip::tcp::socket> m_TcpSocket; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + std::unique_ptr<asio::local::stream_protocol::socket> m_UnixSocket; +#endif + std::unique_ptr<asio::steady_timer> m_Timer; + asio::streambuf m_ReadBuffer; + std::string m_WebSocketKey; + std::shared_ptr<std::string> m_HandshakeBuffer; // Write queue RwLock m_WriteLock; |