diff options
| author | Stefan Boberg <[email protected]> | 2026-03-16 10:56:11 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-16 10:56:11 +0100 |
| commit | 8c3ba4e8c522d119df3cb48966e36c0eaa80aeb9 (patch) | |
| tree | cf51b07e097904044b4bf65bc3fe0ad14134074f /src/zenhttp | |
| parent | Merge branch 'sb/no-network' of https://github.ol.epicgames.net/ue-foundation... (diff) | |
| parent | Enable cross compilation of Windows targets on Linux (#839) (diff) | |
| download | zen-sb/no-network.tar.xz zen-sb/no-network.zip | |
Merge branch 'main' into sb/no-networksb/no-network
Diffstat (limited to 'src/zenhttp')
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.cpp | 3 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.cpp | 1100 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.h | 25 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpwsclient.cpp | 4 | ||||
| -rw-r--r-- | src/zenhttp/httpclient.cpp | 19 | ||||
| -rw-r--r-- | src/zenhttp/httpclient_test.cpp | 107 | ||||
| -rw-r--r-- | src/zenhttp/httpserver.cpp | 290 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/cprutils.h | 22 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpclient.h | 14 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpserver.h | 59 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpwsclient.h | 2 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpasio.cpp | 180 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpplugin.cpp | 2 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpsys.cpp | 2 | ||||
| -rw-r--r-- | src/zenhttp/xmake.lua | 7 |
15 files changed, 741 insertions, 1095 deletions
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index a0f5cc38f..a52b8f74b 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -7,6 +7,7 @@ #include <zencore/compactbinarypackage.h> #include <zencore/compactbinaryutil.h> #include <zencore/compress.h> +#include <zencore/filesystem.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> #include <zencore/session.h> @@ -513,7 +514,7 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, if (!ConnectionSettings.UnixSocketPath.empty()) { - CprSession->SetUnixSocket(cpr::UnixSocket(ConnectionSettings.UnixSocketPath)); + CprSession->SetUnixSocket(cpr::UnixSocket(PathToUtf8(ConnectionSettings.UnixSocketPath))); } if (ConnectionSettings.InsecureSsl || !ConnectionSettings.CaBundlePath.empty()) diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp index 341adc5f7..ec9b7bac6 100644 --- a/src/zenhttp/clients/httpclientcurl.cpp +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -7,6 +7,8 @@ #include <zencore/compactbinarypackage.h> #include <zencore/compactbinaryutil.h> #include <zencore/compress.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> #include <zencore/session.h> @@ -93,15 +95,11 @@ struct HeaderCallbackData std::vector<std::pair<std::string, std::string>>* Headers = nullptr; }; -static size_t -CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value. +// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines). +static std::optional<std::pair<std::string_view, std::string_view>> +ParseHeaderLine(std::string_view Line) { - auto* Data = static_cast<HeaderCallbackData*>(UserData); - size_t TotalBytes = Size * Nmemb; - - std::string_view Line(Buffer, TotalBytes); - - // Trim trailing \r\n while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) { Line.remove_suffix(1); @@ -109,25 +107,39 @@ CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) if (Line.empty()) { - return TotalBytes; + return std::nullopt; } size_t ColonPos = Line.find(':'); - if (ColonPos != std::string_view::npos) + if (ColonPos == std::string_view::npos) { - std::string_view Key = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); + return std::nullopt; + } - // Trim whitespace - while (!Key.empty() && Key.back() == ' ') - { - Key.remove_suffix(1); - } - while (!Value.empty() && Value.front() == ' ') - { - Value.remove_prefix(1); - } + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + return std::pair{Key, Value}; +} + +static size_t +CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<HeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) + { + auto& [Key, Value] = *Header; Data->Headers->emplace_back(std::string(Key), std::string(Value)); } @@ -285,57 +297,102 @@ BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, for (const auto& [Key, Value] : *AdditionalHeader) { - std::string HeaderLine = fmt::format("{}: {}", Key, Value); - Headers = curl_slist_append(Headers, HeaderLine.c_str()); + ExtendableStringBuilder<64> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); } if (!SessionId.empty()) { - std::string SessionHeader = fmt::format("UE-Session: {}", SessionId); - Headers = curl_slist_append(Headers, SessionHeader.c_str()); + ExtendableStringBuilder<64> SessionHeader; + SessionHeader << "UE-Session: " << SessionId; + Headers = curl_slist_append(Headers, SessionHeader.c_str()); } if (AccessToken) { - std::string AuthHeader = fmt::format("Authorization: {}", AccessToken->Value); - Headers = curl_slist_append(Headers, AuthHeader.c_str()); + ExtendableStringBuilder<128> AuthHeader; + AuthHeader << "Authorization: " << AccessToken->Value; + Headers = curl_slist_append(Headers, AuthHeader.c_str()); } for (const auto& [Key, Value] : ExtraHeaders) { - std::string HeaderLine = fmt::format("{}: {}", Key, Value); - Headers = curl_slist_append(Headers, HeaderLine.c_str()); + ExtendableStringBuilder<128> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); } return Headers; } -static std::string -BuildUrlWithParameters(std::string_view BaseUrl, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters) +static HttpClient::KeyValueMap +BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers) +{ + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + return HeaderMap; +} + +// Scans response headers for Content-Type and applies it to the buffer. +static void +ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers) +{ + for (const auto& [Key, Value] : Headers) + { + if (StrCaseCompare(Key, "Content-Type") == 0) + { + Buffer.SetContentType(ParseContentType(Value)); + break; + } + } +} + +static void +AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input) +{ + static constexpr char HexDigits[] = "0123456789ABCDEF"; + static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"); + + for (char C : Input) + { + if (Unreserved.Contains(C)) + { + Out.Append(C); + } + else + { + uint8_t Byte = static_cast<uint8_t>(C); + char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]}; + Out.Append(std::string_view(Encoded, 3)); + } + } +} + +static void +BuildUrlWithParameters(StringBuilderBase& Url, + std::string_view BaseUrl, + std::string_view ResourcePath, + const HttpClient::KeyValueMap& Parameters) { - std::string Url; - Url.reserve(BaseUrl.size() + ResourcePath.size() + 64); - Url.append(BaseUrl); - Url.append(ResourcePath); + Url.Append(BaseUrl); + Url.Append(ResourcePath); if (!Parameters->empty()) { char Separator = '?'; for (const auto& [Key, Value] : *Parameters) { - char* EncodedKey = curl_easy_escape(nullptr, Key.c_str(), static_cast<int>(Key.size())); - char* EncodedValue = curl_easy_escape(nullptr, Value.c_str(), static_cast<int>(Value.size())); - Url += Separator; - Url += EncodedKey; - Url += '='; - Url += EncodedValue; - curl_free(EncodedKey); - curl_free(EncodedValue); + Url.Append(Separator); + AppendUrlEncoded(Url, Key); + Url.Append('='); + AppendUrlEncoded(Url, Value); Separator = '&'; } } - - return Url; } ////////////////////////////////////////////////////////////////////////// @@ -359,6 +416,48 @@ CurlHttpClient::~CurlHttpClient() }); } +CurlHttpClient::Session::~Session() +{ + if (HeaderList) + { + curl_slist_free_all(HeaderList); + } + Outer->ReleaseSession(Handle); +} + +void +CurlHttpClient::Session::SetHeaders(curl_slist* Headers) +{ + if (HeaderList) + { + curl_slist_free_all(HeaderList); + } + HeaderList = Headers; + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, HeaderList); +} + +CurlHttpClient::CurlResult +CurlHttpClient::Session::PerformWithResponseCallbacks() +{ + std::string Body; + WriteCallbackData WriteData{.Body = &Body, + .CheckIfAbortFunction = Outer->m_CheckIfAbortFunction ? &Outer->m_CheckIfAbortFunction : nullptr}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + return Result; +} + CurlHttpClient::CurlResult CurlHttpClient::Session::Perform() { @@ -411,15 +510,7 @@ CurlHttpClient::ResponseWithPayload(std::string_view SessionId, { IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, Result.Body.data(), Result.Body.size()); - for (const auto& [Key, Value] : Result.Headers) - { - if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) - { - const HttpContentType ContentType = ParseContentType(Value); - ResponseBuffer.SetContentType(ContentType); - break; - } - } + ApplyContentTypeFromHeaders(ResponseBuffer, Result.Headers); if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) { @@ -438,15 +529,9 @@ CurlHttpClient::ResponseWithPayload(std::string_view SessionId, return Lhs.RangeOffset < Rhs.RangeOffset; }); - HttpClient::KeyValueMap HeaderMap; - for (const auto& [Key, Value] : Result.Headers) - { - HeaderMap->insert_or_assign(Key, Value); - } - return HttpClient::Response{.StatusCode = WorkResponseCode, .ResponsePayload = std::move(ResponseBuffer), - .Header = std::move(HeaderMap), + .Header = BuildHeaderMap(Result.Headers), .UploadedBytes = Result.UploadedBytes, .DownloadedBytes = Result.DownloadedBytes, .ElapsedSeconds = Result.ElapsedSeconds, @@ -475,16 +560,10 @@ CurlHttpClient::CommonResponse(std::string_view SessionId, } } - HttpClient::KeyValueMap HeaderMap; - for (const auto& [Key, Value] : Result.Headers) - { - HeaderMap->insert_or_assign(Key, Value); - } - return HttpClient::Response{ .StatusCode = WorkResponseCode, .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Result.Body.data(), Result.Body.size()), - .Header = std::move(HeaderMap), + .Header = BuildHeaderMap(Result.Headers), .UploadedBytes = Result.UploadedBytes, .DownloadedBytes = Result.DownloadedBytes, .ElapsedSeconds = Result.ElapsedSeconds, @@ -493,14 +572,8 @@ CurlHttpClient::CommonResponse(std::string_view SessionId, if (WorkResponseCode == HttpResponseCode::NoContent || (Result.Body.empty() && !Payload)) { - HttpClient::KeyValueMap HeaderMap; - for (const auto& [Key, Value] : Result.Headers) - { - HeaderMap->insert_or_assign(Key, Value); - } - return HttpClient::Response{.StatusCode = WorkResponseCode, - .Header = std::move(HeaderMap), + .Header = BuildHeaderMap(Result.Headers), .UploadedBytes = Result.UploadedBytes, .DownloadedBytes = Result.DownloadedBytes, .ElapsedSeconds = Result.ElapsedSeconds}; @@ -519,25 +592,43 @@ CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::Temp IoBuffer ResponseBuffer = (Result.Body.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() : IoBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size()); - // Find Content-Length in headers + // Collect relevant headers in a single pass + std::string_view ContentLengthValue; + std::string_view IoHashValue; + std::string_view ContentTypeValue; + for (const auto& [Key, Value] : Result.Headers) { - if (StrCaseCompare(Key.c_str(), "Content-Length") == 0) + if (ContentLengthValue.empty() && StrCaseCompare(Key, "Content-Length") == 0) { - std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(Value); - if (!ExpectedContentSize.has_value()) - { - Result.ErrorCode = CURLE_RECV_ERROR; - Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", Value); - return false; - } - if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) - { - Result.ErrorCode = CURLE_RECV_ERROR; - Result.ErrorMessage = fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), Value); - return false; - } - break; + ContentLengthValue = Value; + } + else if (IoHashValue.empty() && StrCaseCompare(Key, "X-Jupiter-IoHash") == 0) + { + IoHashValue = Value; + } + else if (ContentTypeValue.empty() && StrCaseCompare(Key, "Content-Type") == 0) + { + ContentTypeValue = Value; + } + } + + // Validate Content-Length + if (!ContentLengthValue.empty()) + { + std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLengthValue); + if (!ExpectedContentSize.has_value()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLengthValue); + return false; + } + if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = + fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLengthValue); + return false; } } @@ -546,66 +637,55 @@ CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::Temp return true; } - // Check X-Jupiter-IoHash - for (const auto& [Key, Value] : Result.Headers) + // Validate X-Jupiter-IoHash + if (!IoHashValue.empty()) { - if (StrCaseCompare(Key.c_str(), "X-Jupiter-IoHash") == 0) + IoHash ExpectedPayloadHash; + if (IoHash::TryParse(IoHashValue, ExpectedPayloadHash)) { - IoHash ExpectedPayloadHash; - if (IoHash::TryParse(Value, ExpectedPayloadHash)) + IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); + if (PayloadHash != ExpectedPayloadHash) { - IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); - if (PayloadHash != ExpectedPayloadHash) - { - Result.ErrorCode = CURLE_RECV_ERROR; - Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", - PayloadHash.ToHexString(), - ExpectedPayloadHash.ToHexString()); - return false; - } + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", + PayloadHash.ToHexString(), + ExpectedPayloadHash.ToHexString()); + return false; } - break; } } // Validate content-type specific payload - for (const auto& [Key, Value] : Result.Headers) + if (ContentTypeValue == "application/x-ue-comp") { - if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, + RawHash, + RawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)) { - if (Value == "application/x-ue-comp") - { - IoHash RawHash; - uint64_t RawSize; - if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, - RawHash, - RawSize, - /*OutOptionalTotalCompressedSize*/ nullptr)) - { - return true; - } - else - { - Result.ErrorCode = CURLE_RECV_ERROR; - Result.ErrorMessage = "Compressed binary failed validation"; - return false; - } - } - if (Value == "application/x-ue-cb") - { - if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); - Error == CbValidateError::None) - { - return true; - } - else - { - Result.ErrorCode = CURLE_RECV_ERROR; - Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error)); - return false; - } - } - break; + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = "Compressed binary failed validation"; + return false; + } + } + if (ContentTypeValue == "application/x-ue-cb") + { + if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); + Error == CbValidateError::None) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error)); + return false; } } @@ -666,10 +746,24 @@ CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult Attempt++; if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) { - ZEN_INFO("{} Attempt {}/{}", - CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), - Attempt, - m_ConnectionSettings.RetryCount + 1); + if (Result.ErrorCode != CURLE_OK) + { + ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}", + SessionId, + static_cast<int>(MapCurlError(Result.ErrorCode)), + Result.ErrorMessage, + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + else + { + ZEN_INFO("Retry (session: {}): HTTP status ({}) '{}' Attempt {}/{}", + SessionId, + Result.StatusCode, + zen::ToString(HttpResponseCode(Result.StatusCode)), + Attempt, + m_ConnectionSettings.RetryCount + 1); + } } Result = Func(); } @@ -681,51 +775,14 @@ CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) { - uint8_t Attempt = 0; - CurlResult Result = Func(); - while (Attempt < m_ConnectionSettings.RetryCount) - { - if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) - { - return Result; - } - if (!ShouldRetry(Result)) - { - if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) - { - break; - } - if (ValidatePayload(Result, PayloadFile)) - { - break; - } - } - Sleep(100 * (Attempt + 1)); - Attempt++; - if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) - { - ZEN_INFO("{} Attempt {}/{}", - CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), - Attempt, - m_ConnectionSettings.RetryCount + 1); - } - Result = Func(); - } - return Result; + return DoWithRetry(SessionId, std::move(Func), [&](CurlResult& Result) { return ValidatePayload(Result, PayloadFile); }); } ////////////////////////////////////////////////////////////////////////// CurlHttpClient::Session -CurlHttpClient::AllocSession(std::string_view BaseUrl, - std::string_view ResourcePath, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken) +CurlHttpClient::AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters) { - ZEN_UNUSED(AccessToken, SessionId, AdditionalHeader); ZEN_TRACE_CPU("CurlHttpClient::AllocSession"); CURL* Handle = nullptr; m_SessionLock.WithExclusiveLock([&] { @@ -739,6 +796,10 @@ CurlHttpClient::AllocSession(std::string_view BaseUrl, if (Handle == nullptr) { Handle = curl_easy_init(); + if (Handle == nullptr) + { + ThrowOutOfMemory("curl_easy_init"); + } } else { @@ -746,33 +807,35 @@ CurlHttpClient::AllocSession(std::string_view BaseUrl, } // Unix domain socket - if (!ConnectionSettings.UnixSocketPath.empty()) + if (!m_ConnectionSettings.UnixSocketPath.empty()) { - curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, ConnectionSettings.UnixSocketPath.c_str()); + std::string SocketPathUtf8 = PathToUtf8(m_ConnectionSettings.UnixSocketPath); + curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, SocketPathUtf8.c_str()); } // Build URL with parameters - std::string Url = BuildUrlWithParameters(BaseUrl, ResourcePath, Parameters); + ExtendableStringBuilder<256> Url; + BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters); curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); // Timeouts - if (ConnectionSettings.ConnectTimeout.count() > 0) + if (m_ConnectionSettings.ConnectTimeout.count() > 0) { - curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(ConnectionSettings.ConnectTimeout.count())); + curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(m_ConnectionSettings.ConnectTimeout.count())); } - if (ConnectionSettings.Timeout.count() > 0) + if (m_ConnectionSettings.Timeout.count() > 0) { - curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(ConnectionSettings.Timeout.count())); + curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_ConnectionSettings.Timeout.count())); } // HTTP/2 - if (ConnectionSettings.AssumeHttp2) + if (m_ConnectionSettings.AssumeHttp2) { curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); } // Verbose/debug - if (ConnectionSettings.Verbose) + if (m_ConnectionSettings.Verbose) { curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); curl_easy_setopt(Handle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback); @@ -780,27 +843,27 @@ CurlHttpClient::AllocSession(std::string_view BaseUrl, } // SSL options - if (ConnectionSettings.InsecureSsl) + if (m_ConnectionSettings.InsecureSsl) { curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L); curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L); } - if (!ConnectionSettings.CaBundlePath.empty()) + if (!m_ConnectionSettings.CaBundlePath.empty()) { - curl_easy_setopt(Handle, CURLOPT_CAINFO, ConnectionSettings.CaBundlePath.c_str()); + curl_easy_setopt(Handle, CURLOPT_CAINFO, m_ConnectionSettings.CaBundlePath.c_str()); } // Disable signal handling for thread safety curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); - if (ConnectionSettings.ForbidReuseConnection) + if (m_ConnectionSettings.ForbidReuseConnection) { curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); } // Note: Headers are NOT set here. Each method builds its own header list - // (potentially adding method-specific headers like Content-Type) and is - // responsible for freeing it with curl_slist_free_all. + // (potentially adding method-specific headers like Content-Type) and passes + // ownership to the Session via SetHeaders(). return Session(this, Handle); } @@ -809,15 +872,13 @@ void CurlHttpClient::ReleaseSession(CURL* Handle) { ZEN_TRACE_CPU("CurlHttpClient::ReleaseSession"); - - // Free any header list that was set - // curl_easy_reset will be called on next AllocSession, which cleans up the handle state. - // We just push the handle back to the pool. m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(Handle); }); } ////////////////////////////////////////////////////////////////////////// +// TransactPackage is a two-phase protocol (offer + send) with server-side state +// between phases, so retrying individual phases would be incorrect. CurlHttpClient::Response CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) { @@ -831,7 +892,7 @@ CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const K const uint32_t RequestId = ++CurlHttpClientRequestIdCounter; auto RequestIdString = fmt::to_string(RequestId); - if (Attachments.empty() == false) + if (!Attachments.empty()) { CbObjectWriter Writer; Writer.BeginArray("offer"); @@ -850,27 +911,19 @@ CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const K OfferExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackageOffer)); OfferExtraHeaders.emplace_back("UE-Request", RequestIdString); - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders)); curl_easy_setopt(H, CURLOPT_POST, 1L); curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(MemWriter.Data())); curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(MemWriter.Size())); - std::string FilterBody; - WriteCallbackData WriteData{.Body = &FilterBody}; - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - - CurlResult Result = Sess.Perform(); - - curl_slist_free_all(HeaderList); + CurlResult Result = Sess.PerformWithResponseCallbacks(); - if (Result.ErrorCode == CURLE_OK && Result.StatusCode == 200) + if (Result.ErrorCode == CURLE_OK && IsHttpSuccessCode(Result.StatusCode)) { - IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterBody.data(), FilterBody.size()); + IoBuffer ResponseBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size()); CbValidateError ValidationError = CbValidateError::None; if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); ValidationError == CbValidateError::None) @@ -908,41 +961,17 @@ CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const K PkgExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackage)); PkgExtraHeaders.emplace_back("UE-Request", RequestIdString); - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders)); curl_easy_setopt(H, CURLOPT_POST, 1L); curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(FlatMessage.GetData())); curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(FlatMessage.GetSize())); - std::string PkgBody; - WriteCallbackData WriteData{.Body = &PkgBody}; - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - - CurlResult Result = Sess.Perform(); - - curl_slist_free_all(HeaderList); + CurlResult Result = Sess.PerformWithResponseCallbacks(); - if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) - { - return {.StatusCode = HttpResponseCode(Result.StatusCode)}; - } - - IoBuffer ResponseBuffer(IoBuffer::Clone, PkgBody.data(), PkgBody.size()); - - for (const auto& [Key, Value] : Result.Headers) - { - if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) - { - ResponseBuffer.SetContentType(ParseContentType(Value)); - break; - } - } - - return {.StatusCode = HttpResponseCode(Result.StatusCode), .ResponsePayload = std::move(ResponseBuffer)}; + return CommonResponse(m_SessionId, std::move(Result), {}, {}); } ////////////////////////////////////////////////////////////////////////// @@ -957,44 +986,26 @@ CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValu return CommonResponse( m_SessionId, - DoWithRetry(m_SessionId, - [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = - BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - - curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); - curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); - - ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), - .DataSize = Payload.GetSize(), - .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; - curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); - curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())})); - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); - curl_slist_free_all(Headers); + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - return Result; - }), + return Sess.PerformWithResponseCallbacks(); + }), {}); } @@ -1005,39 +1016,19 @@ CurlHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) return CommonResponse( m_SessionId, - DoWithRetry( - m_SessionId, - [&]() -> CurlResult { - KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeaderWithContentLength, Parameters, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - - curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); - curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL); - - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; + Session Sess = AllocSession(Url, Parameters); + CURL* H = Sess.Get(); - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); + Sess.SetHeaders(BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken())); - curl_slist_free_all(Headers); + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL); - return Result; - }), + return Sess.PerformWithResponseCallbacks(); + }), {}); } @@ -1045,43 +1036,20 @@ CurlHttpClient::Response CurlHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) { ZEN_TRACE_CPU("CurlHttpClient::Get"); - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); - - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - - return Result; - }, - [this](CurlResult& Result) { - std::unique_ptr<detail::TempPayloadFile> NoTempFile; - return ValidatePayload(Result, NoTempFile); - }), - {}); + return CommonResponse(m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, Parameters); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_HTTPGET, 1L); + return Sess.PerformWithResponseCallbacks(); + }, + [this](CurlResult& Result) { + std::unique_ptr<detail::TempPayloadFile> NoTempFile; + return ValidatePayload(Result, NoTempFile); + }), + {}); } CurlHttpClient::Response @@ -1089,33 +1057,15 @@ CurlHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) { ZEN_TRACE_CPU("CurlHttpClient::Head"); - return CommonResponse( - m_SessionId, - DoWithRetry(m_SessionId, - [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - curl_easy_setopt(H, CURLOPT_NOBODY, 1L); - - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - - return Result; - }), - {}); + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_NOBODY, 1L); + return Sess.PerformWithResponseCallbacks(); + }), + {}); } CurlHttpClient::Response @@ -1123,38 +1073,15 @@ CurlHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader { ZEN_TRACE_CPU("CurlHttpClient::Delete"); - return CommonResponse( - m_SessionId, - DoWithRetry(m_SessionId, - [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - curl_easy_setopt(H, CURLOPT_CUSTOMREQUEST, "DELETE"); - - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - - return Result; - }), - {}); + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_CUSTOMREQUEST, "DELETE"); + return Sess.PerformWithResponseCallbacks(); + }), + {}); } CurlHttpClient::Response @@ -1162,39 +1089,16 @@ CurlHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, { ZEN_TRACE_CPU("CurlHttpClient::PostNoPayload"); - return CommonResponse( - m_SessionId, - DoWithRetry(m_SessionId, - [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - curl_easy_setopt(H, CURLOPT_POST, 1L); - curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE, 0L); - - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - - return Result; - }), - {}); + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, Parameters); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_POST, 1L); + curl_easy_setopt(Sess.Get(), CURLOPT_POSTFIELDSIZE, 0L); + return Sess.PerformWithResponseCallbacks(); + }), + {}); } CurlHttpClient::Response @@ -1213,12 +1117,10 @@ CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentTy DoWithRetry( m_SessionId, [&]() -> CurlResult { - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - // Rebuild headers with content type - curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); IoBufferFileReference FileRef = {nullptr, 0, 0}; if (Payload.GetFileReference(FileRef)) @@ -1234,46 +1136,14 @@ CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentTy curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); } curl_easy_setopt(H, CURLOPT_POST, 1L); curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetData())); curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); }), {}); } @@ -1295,12 +1165,11 @@ CurlHttpClient::Post(std::string_view Url, PayloadString.clear(); PayloadFile.reset(); - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - curl_slist* Headers = - BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)})); curl_easy_setopt(H, CURLOPT_POST, 1L); curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData())); @@ -1329,33 +1198,11 @@ CurlHttpClient::Post(std::string_view Url, auto* Data = static_cast<PostHeaderCallbackData*>(UserData); size_t TotalBytes = Size * Nmemb; - std::string_view Line(Buffer, TotalBytes); - while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) { - Line.remove_suffix(1); - } - - if (Line.empty()) - { - return TotalBytes; - } - - size_t ColonPos = Line.find(':'); - if (ColonPos != std::string_view::npos) - { - std::string_view Key = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); - - while (!Key.empty() && Key.back() == ' ') - { - Key.remove_suffix(1); - } - while (!Value.empty() && Value.front() == ' ') - { - Value.remove_prefix(1); - } + auto& [Key, Value] = *Header; - if (StrCaseCompare(std::string(Key).c_str(), "Content-Length") == 0) + if (StrCaseCompare(Key, "Content-Length") == 0) { std::optional<size_t> ContentLength = ParseInt<size_t>(Value); if (ContentLength.has_value()) @@ -1444,7 +1291,6 @@ CurlHttpClient::Post(std::string_view Url, Res.Body = std::move(PayloadString); } - curl_slist_free_all(Headers); return Res; }, PayloadFile); @@ -1467,13 +1313,10 @@ CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenCo m_SessionId, DoWithRetry(m_SessionId, [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); - curl_slist* Headers = - BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); @@ -1485,23 +1328,7 @@ CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenCo curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); }), {}); } @@ -1516,12 +1343,11 @@ CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyV DoWithRetry( m_SessionId, [&]() -> CurlResult { - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - curl_slist* Headers = - BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())})); curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); @@ -1538,23 +1364,7 @@ CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyV curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); } ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), @@ -1563,23 +1373,7 @@ CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyV curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); }), {}); } @@ -1596,13 +1390,10 @@ CurlHttpClient::Upload(std::string_view Url, m_SessionId, DoWithRetry(m_SessionId, [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); - curl_slist* Headers = - BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); @@ -1615,23 +1406,7 @@ CurlHttpClient::Upload(std::string_view Url, curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); }), {}); } @@ -1651,11 +1426,10 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp CurlResult Result = DoWithRetry( m_SessionId, [&]() -> CurlResult { - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - curl_slist* DlHeaders = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, DlHeaders); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); // Reset state from any previous attempt @@ -1673,7 +1447,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp { std::string_view RangeValue(RangeIt->second); size_t RangeStartPos = RangeValue.find('=', 5); - if (RangeStartPos != std::string::npos) + if (RangeStartPos != std::string_view::npos) { RangeStartPos++; while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') @@ -1685,14 +1459,14 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp while (RangeStartPos < RangeValue.length()) { size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); - if (RangeEnd == std::string::npos) + if (RangeEnd == std::string_view::npos) { RangeEnd = RangeValue.length(); } std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); size_t RangeSplitPos = RangeString.find('-'); - if (RangeSplitPos != std::string::npos) + if (RangeSplitPos != std::string_view::npos) { std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos)); std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1)); @@ -1742,36 +1516,12 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp auto* Data = static_cast<DownloadHeaderCallbackData*>(UserData); size_t TotalBytes = Size * Nmemb; - std::string_view Line(Buffer, TotalBytes); - - while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) - { - Line.remove_suffix(1); - } - - if (Line.empty()) + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) { - return TotalBytes; - } - - size_t ColonPos = Line.find(':'); - if (ColonPos != std::string_view::npos) - { - std::string_view KeyView = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); - - while (!KeyView.empty() && KeyView.back() == ' ') - { - KeyView.remove_suffix(1); - } - while (!Value.empty() && Value.front() == ' ') - { - Value.remove_prefix(1); - } - + auto& [KeyView, Value] = *Header; const std::string Key(KeyView); - if (StrCaseCompare(Key.c_str(), "Content-Length") == 0) + if (StrCaseCompare(Key, "Content-Length") == 0) { std::optional<size_t> ContentLength = ParseInt<size_t>(Value); if (ContentLength.has_value()) @@ -1795,7 +1545,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp } } } - else if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) + else if (StrCaseCompare(Key, "Content-Type") == 0) { *Data->IsMultiRange = Data->BoundaryParser->Init(Value); if (!*Data->IsMultiRange) @@ -1803,7 +1553,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp *Data->ContentTypeOut = ParseContentType(Value); } } - else if (StrCaseCompare(Key.c_str(), "Content-Range") == 0) + else if (StrCaseCompare(Key, "Content-Range") == 0) { if (!*Data->IsMultiRange) { @@ -1819,7 +1569,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp } } - Data->Headers->emplace_back(std::string(Key), std::string(Value)); + Data->Headers->emplace_back(Key, std::string(Value)); } return TotalBytes; @@ -1894,11 +1644,11 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp auto SupportsRanges = [](const CurlResult& R) -> bool { for (const auto& [K, V] : R.Headers) { - if (StrCaseCompare(K.c_str(), "Content-Range") == 0) + if (StrCaseCompare(K, "Content-Range") == 0) { return true; } - if (StrCaseCompare(K.c_str(), "Accept-Ranges") == 0) + if (StrCaseCompare(K, "Accept-Ranges") == 0) { return V == "bytes"sv; } @@ -1924,7 +1674,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp std::string ContentLengthValue; for (const auto& [K, V] : Res.Headers) { - if (StrCaseCompare(K.c_str(), "Content-Length") == 0) + if (StrCaseCompare(K, "Content-Length") == 0) { ContentLengthValue = V; break; @@ -1943,6 +1693,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp } KeyValueMap HeadersWithRange(AdditionalHeader); + uint8_t ResumeAttempt = 0; do { uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); @@ -1957,12 +1708,10 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp } HeadersWithRange.Entries.insert_or_assign("Range", Range); - Session ResumeSess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); - CURL* ResumeH = ResumeSess.Get(); + Session ResumeSess = AllocSession(Url, {}); + CURL* ResumeH = ResumeSess.Get(); - curl_slist* ResumeHdrList = BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken()); - curl_easy_setopt(ResumeH, CURLOPT_HTTPHEADER, ResumeHdrList); + ResumeSess.SetHeaders(BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken())); curl_easy_setopt(ResumeH, CURLOPT_HTTPGET, 1L); std::vector<std::pair<std::string, std::string>> ResumeHeaders; @@ -1983,72 +1732,51 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp auto* Data = static_cast<ResumeHeaderCbData*>(UserData); size_t TotalBytes = Size * Nmemb; - std::string_view Line(Buffer, TotalBytes); - while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) - { - Line.remove_suffix(1); - } - - if (Line.empty()) + auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)); + if (!Header) { return TotalBytes; } + auto& [Key, Value] = *Header; - size_t ColonPos = Line.find(':'); - if (ColonPos != std::string_view::npos) + if (StrCaseCompare(Key, "Content-Range") == 0) { - std::string_view Key = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); - while (!Key.empty() && Key.back() == ' ') + if (Value.starts_with("bytes "sv)) { - Key.remove_suffix(1); - } - while (!Value.empty() && Value.front() == ' ') - { - Value.remove_prefix(1); - } - - if (StrCaseCompare(std::string(Key).c_str(), "Content-Range") == 0) - { - if (Value.starts_with("bytes "sv)) + size_t RangeStartEnd = Value.find('-', 6); + if (RangeStartEnd != std::string_view::npos) { - size_t RangeStartEnd = Value.find('-', 6); - if (RangeStartEnd != std::string_view::npos) + const std::optional<uint64_t> Start = ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6)); + if (Start) { - const std::optional<uint64_t> Start = - ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6)); - if (Start) + uint64_t DownloadedSize = + *Data->PayloadFile ? (*Data->PayloadFile)->GetSize() : Data->PayloadString->length(); + if (Start.value() == DownloadedSize) { - uint64_t DownloadedSize = *Data->PayloadFile ? (*Data->PayloadFile)->GetSize() - : Data->PayloadString->length(); - if (Start.value() == DownloadedSize) - { - Data->Headers->emplace_back(std::string(Key), std::string(Value)); - return TotalBytes; - } - else if (Start.value() > DownloadedSize) - { - return 0; - } - if (*Data->PayloadFile) - { - (*Data->PayloadFile)->ResetWritePos(Start.value()); - } - else - { - *Data->PayloadString = Data->PayloadString->substr(0, Start.value()); - } Data->Headers->emplace_back(std::string(Key), std::string(Value)); return TotalBytes; } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (*Data->PayloadFile) + { + (*Data->PayloadFile)->ResetWritePos(Start.value()); + } + else + { + *Data->PayloadString = Data->PayloadString->substr(0, Start.value()); + } + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; } } - return 0; } - - Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return 0; } + Data->Headers->emplace_back(std::string(Key), std::string(Value)); return TotalBytes; }; @@ -2064,8 +1792,8 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp Res = ResumeSess.Perform(); Res.Headers = std::move(ResumeHeaders); - curl_slist_free_all(ResumeHdrList); - } while (ShouldResumeCheck(Res)); + ResumeAttempt++; + } while (ResumeAttempt < m_ConnectionSettings.RetryCount && ShouldResumeCheck(Res)); } } } @@ -2075,8 +1803,6 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp Res.Body = std::move(PayloadString); } - curl_slist_free_all(DlHeaders); - return Res; }, PayloadFile); diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h index 871877863..b7fa52e6c 100644 --- a/src/zenhttp/clients/httpclientcurl.h +++ b/src/zenhttp/clients/httpclientcurl.h @@ -75,40 +75,39 @@ private: struct Session { Session(CurlHttpClient* InOuter, CURL* InHandle) : Outer(InOuter), Handle(InHandle) {} - ~Session() { Outer->ReleaseSession(Handle); } + ~Session(); CURL* Get() const { return Handle; } + // Takes ownership of the curl_slist and sets it on the handle. + // The list is freed automatically when the Session is destroyed. + void SetHeaders(curl_slist* Headers); + + // Low-level perform: executes the request and collects status/timing. CurlResult Perform(); + // Sets up standard write+header callbacks, performs the request, and + // moves the collected body and headers into the returned CurlResult. + CurlResult PerformWithResponseCallbacks(); + LoggerRef Log() { return Outer->Log(); } private: CurlHttpClient* Outer; CURL* Handle; + curl_slist* HeaderList = nullptr; Session(Session&&) = delete; Session& operator=(Session&&) = delete; }; - Session AllocSession(std::string_view BaseUrl, - std::string_view Url, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken); + Session AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters); RwLock m_SessionLock; std::vector<CURL*> m_Sessions; void ReleaseSession(CURL* Handle); - struct RetryResult - { - CurlResult Result; - }; - CurlResult DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::unique_ptr<detail::TempPayloadFile>& PayloadFile); diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 2d566ae86..fbae9f5fe 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -5,6 +5,8 @@ #include "../servers/wsframecodec.h" #include <zencore/base64.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/string.h> @@ -155,7 +157,7 @@ struct HttpWsClient::Impl } }); - asio::local::stream_protocol::endpoint Endpoint(m_Settings.UnixSocketPath); + asio::local::stream_protocol::endpoint Endpoint(PathToUtf8(m_Settings.UnixSocketPath)); m_UnixSocket->async_connect(Endpoint, [this](const asio::error_code& Ec) { if (Ec) { diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index deeeb6c85..9f49802a0 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -36,15 +36,17 @@ namespace zen { +#if ZEN_WITH_CPR extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction); +#endif extern HttpClientBase* CreateCurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction); -static HttpClientBackend g_DefaultHttpClientBackend = HttpClientBackend::kCpr; +static HttpClientBackend g_DefaultHttpClientBackend = HttpClientBackend::kCurl; void SetDefaultHttpClientBackend(HttpClientBackend Backend) @@ -55,11 +57,14 @@ SetDefaultHttpClientBackend(HttpClientBackend Backend) void SetDefaultHttpClientBackend(std::string_view Backend) { +#if ZEN_WITH_CPR if (Backend == "cpr") { g_DefaultHttpClientBackend = HttpClientBackend::kCpr; } - else if (Backend == "curl") + else +#endif + if (Backend == "curl") { g_DefaultHttpClientBackend = HttpClientBackend::kCurl; } @@ -363,13 +368,15 @@ HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& Conne switch (EffectiveBackend) { - case HttpClientBackend::kCurl: - m_Inner = CreateCurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); - break; +#if ZEN_WITH_CPR case HttpClientBackend::kCpr: - default: m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); break; +#endif + case HttpClientBackend::kCurl: + default: + m_Inner = CreateCurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); + break; } } diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index 5f3ad2455..3ca586f87 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -154,6 +154,42 @@ public: }, HttpVerb::kGet); + m_Router.AddMatcher("anypath", [](std::string_view Str) -> bool { return !Str.empty(); }); + + m_Router.RegisterRoute( + "echo/uri", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string Body = std::string(HttpReq.RelativeUri()); + + auto Params = HttpReq.GetQueryParams(); + for (const auto& [Key, Value] : Params.KvPairs) + { + Body += fmt::format("\n{}={}", Key, Value); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Body); + }, + HttpVerb::kGet | HttpVerb::kPut); + + m_Router.RegisterRoute( + "echo/uri/{anypath}", + [](HttpRouterRequest& Req) { + // Echo both the RelativeUri and the captured path segment + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Captured = Req.GetCapture(1); + std::string Body = fmt::format("uri={}\ncapture={}", HttpReq.RelativeUri(), Captured); + + auto Params = HttpReq.GetQueryParams(); + for (const auto& [Key, Value] : Params.KvPairs) + { + Body += fmt::format("\n{}={}", Key, Value); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Body); + }, + HttpVerb::kGet | HttpVerb::kPut); + m_Router.RegisterRoute( "slow", [](HttpRouterRequest& Req) { @@ -1689,6 +1725,77 @@ TEST_CASE("httpclient.https") # endif // ZEN_USE_OPENSSL +TEST_CASE("httpclient.uri_decoding") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + // URI without encoding — should pass through unchanged + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello/world.txt"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "uri=echo/uri/hello/world.txt\ncapture=hello/world.txt"); + } + + // Percent-encoded space — server should see decoded path + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello%20world.txt"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "uri=echo/uri/hello world.txt\ncapture=hello world.txt"); + } + + // Percent-encoded slash (%2F) — should be decoded to / + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/a%2Fb.txt"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "uri=echo/uri/a/b.txt\ncapture=a/b.txt"); + } + + // Multiple encodings in one path + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/file%20%26%20name.txt"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "uri=echo/uri/file & name.txt\ncapture=file & name.txt"); + } + + // No capture — echo/uri route returns just RelativeUri + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "echo/uri"); + } + + // Literal percent that is not an escape (%ZZ) — should be kept as-is + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/100%25done.txt"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "uri=echo/uri/100%done.txt\ncapture=100%done.txt"); + } + + // Query params — raw values are returned as-is from GetQueryParams + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri?key=value&name=test"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "echo/uri\nkey=value\nname=test"); + } + + // Query params with percent-encoded values + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri?prefix=listing%2F&mode=s3"); + REQUIRE(Resp.IsSuccess()); + // GetQueryParams returns raw (still-encoded) values — callers must Decode() explicitly + CHECK(Resp.AsText() == "echo/uri\nprefix=listing%2F\nmode=s3"); + } + + // Query params with path capture and encoding + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello%20world.txt?tag=a%26b"); + REQUIRE(Resp.IsSuccess()); + // Path is decoded, query values are raw + CHECK(Resp.AsText() == "uri=echo/uri/hello world.txt\ncapture=hello world.txt\ntag=a%26b"); + } +} + TEST_SUITE_END(); void diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index 6ba0ca563..4d98e9650 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -2,6 +2,8 @@ #include <zenhttp/httpserver.h> +#include <zencore/filesystem.h> + #include "servers/httpasio.h" #include "servers/httpmulti.h" #include "servers/httpnull.h" @@ -698,15 +700,6 @@ HttpServerRequest::ReadPayloadPackage() ////////////////////////////////////////////////////////////////////////// void -HttpRequestRouter::AddPattern(const char* Id, const char* Regex) -{ - ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end()); - ZEN_ASSERT(!m_IsFinalized); - - m_PatternMap.insert({Id, Regex}); -} - -void HttpRequestRouter::AddMatcher(const char* Id, std::function<bool(std::string_view)>&& Matcher) { ZEN_ASSERT(m_MatcherNameMap.find(Id) == m_MatcherNameMap.end()); @@ -722,170 +715,77 @@ HttpRequestRouter::RegisterRoute(const char* UriPattern, HttpRequestRouter::Hand { ZEN_ASSERT(!m_IsFinalized); - if (ExtendableStringBuilder<128> ExpandedRegex; ProcessRegexSubstitutions(UriPattern, ExpandedRegex)) - { - // Regex route - m_RegexHandlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), UriPattern); - } - else - { - // New-style regex-free route. More efficient and should be used for everything eventually - - int RegexLen = gsl::narrow_cast<int>(strlen(UriPattern)); + int RegexLen = gsl::narrow_cast<int>(strlen(UriPattern)); - int i = 0; + int i = 0; - std::vector<int> MatcherIndices; + std::vector<int> MatcherIndices; - while (i < RegexLen) + while (i < RegexLen) + { + if (UriPattern[i] == '{') { - if (UriPattern[i] == '{') + bool IsComplete = false; + int PatternStart = i + 1; + while (++i < RegexLen) { - bool IsComplete = false; - int PatternStart = i + 1; - while (++i < RegexLen) + if (UriPattern[i] == '}') { - if (UriPattern[i] == '}') + if (i == PatternStart) { - if (i == PatternStart) - { - throw std::runtime_error(fmt::format("matcher pattern is empty in URI pattern '{}'", UriPattern)); - } - std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart); - if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end()) - { - // It's a match - MatcherIndices.push_back(it->second); - IsComplete = true; - ++i; - break; - } - else - { - throw std::runtime_error(fmt::format("unknown matcher pattern '{}' in URI pattern '{}'", Pattern, UriPattern)); - } + throw std::runtime_error(fmt::format("matcher pattern is empty in URI pattern '{}'", UriPattern)); } - } - if (!IsComplete) - { - throw std::runtime_error(fmt::format("unterminated matcher pattern in URI pattern '{}'", UriPattern)); - } - } - else - { - if (UriPattern[i] == '/') - { - throw std::runtime_error(fmt::format("unexpected '/' in literal segment of URI pattern '{}'", UriPattern)); - } - - int SegmentStart = i; - while (++i < RegexLen && UriPattern[i] != '/') - ; - - std::string_view Segment(&UriPattern[SegmentStart], (i - SegmentStart)); - int LiteralIndex = gsl::narrow_cast<int>(m_Literals.size()); - m_Literals.push_back(std::string(Segment)); - MatcherIndices.push_back(-1 - LiteralIndex); - } - - if (i < RegexLen && UriPattern[i] == '/') - { - ++i; // skip slash - } - } - - m_MatcherEndpoints.emplace_back(std::move(MatcherIndices), SupportedVerbs, std::move(HandlerFunc), UriPattern); - } -} - -std::string_view -HttpRouterRequest::GetCapture(uint32_t Index) const -{ - if (!m_CapturedSegments.empty()) - { - ZEN_ASSERT(Index < m_CapturedSegments.size()); - return m_CapturedSegments[Index]; - } - - ZEN_ASSERT(Index < m_Match.size()); - - const auto& Match = m_Match[Index]; - - return std::string_view(&*Match.first, Match.second - Match.first); -} - -bool -HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& OutExpandedRegex) -{ - size_t RegexLen = strlen(Regex); - - bool HasRegex = false; - - std::vector<std::string> UnknownPatterns; - - for (size_t i = 0; i < RegexLen;) - { - bool matched = false; - - if (Regex[i] == '{' && ((i == 0) || (Regex[i - 1] != '\\'))) - { - // Might have a pattern reference - find closing brace - - for (size_t j = i + 1; j < RegexLen; ++j) - { - if (Regex[j] == '}') - { - std::string Pattern(&Regex[i + 1], j - i - 1); - - if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end()) + std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart); + if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end()) { - OutExpandedRegex.Append(it->second.c_str()); - HasRegex = true; + // It's a match + MatcherIndices.push_back(it->second); + IsComplete = true; + ++i; + break; } else { - UnknownPatterns.push_back(Pattern); + throw std::runtime_error(fmt::format("unknown matcher pattern '{}' in URI pattern '{}'", Pattern, UriPattern)); } - - // skip ahead - i = j + 1; - - matched = true; - - break; } } + if (!IsComplete) + { + throw std::runtime_error(fmt::format("unterminated matcher pattern in URI pattern '{}'", UriPattern)); + } } - - if (!matched) - { - OutExpandedRegex.Append(Regex[i++]); - } - } - - if (HasRegex) - { - if (UnknownPatterns.size() > 0) + else { - std::string UnknownList; - for (const auto& Pattern : UnknownPatterns) + if (UriPattern[i] == '/') { - if (!UnknownList.empty()) - { - UnknownList += ", "; - } - UnknownList += "'"; - UnknownList += Pattern; - UnknownList += "'"; + throw std::runtime_error(fmt::format("unexpected '/' in literal segment of URI pattern '{}'", UriPattern)); } - throw std::runtime_error(fmt::format("unknown pattern(s) {} in regex route '{}'", UnknownList, Regex)); + int SegmentStart = i; + while (++i < RegexLen && UriPattern[i] != '/') + ; + + std::string_view Segment(&UriPattern[SegmentStart], (i - SegmentStart)); + int LiteralIndex = gsl::narrow_cast<int>(m_Literals.size()); + m_Literals.push_back(std::string(Segment)); + MatcherIndices.push_back(-1 - LiteralIndex); } - return true; + if (i < RegexLen && UriPattern[i] == '/') + { + ++i; // skip slash + } } - return false; + m_MatcherEndpoints.emplace_back(std::move(MatcherIndices), SupportedVerbs, std::move(HandlerFunc), UriPattern); +} + +std::string_view +HttpRouterRequest::GetCapture(uint32_t Index) const +{ + ZEN_ASSERT(Index < m_CapturedSegments.size()); + return m_CapturedSegments[Index]; } bool @@ -901,8 +801,6 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) std::string_view Uri = Request.RelativeUri(); HttpRouterRequest RouterRequest(Request); - // First try new-style matcher routes - for (const MatcherEndpoint& Handler : m_MatcherEndpoints) { if ((Handler.Verbs & Verb) == Verb) @@ -1000,28 +898,6 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) } } - // Old-style regex routes - - for (const auto& Handler : m_RegexHandlers) - { - if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx)) - { -#if ZEN_WITH_OTEL - if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) - { - ExtendableStringBuilder<128> RoutePath; - RoutePath.Append(Request.Service().BaseUri()); - RoutePath.Append(Handler.Pattern); - ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); - } -#endif - - Handler.Handler(RouterRequest); - - return true; // Route matched - } - } - return false; // No route matched } @@ -1157,7 +1033,7 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig ZEN_INFO("using asio HTTP server implementation") return CreateHttpAsioServer(AsioConfig { .ThreadCount = Config.ThreadCount, .ForceLoopback = Config.ForceLoopback, .IsDedicatedServer = Config.IsDedicatedServer, - .NoNetwork = Config.NoNetwork, .UnixSocketPath = Config.UnixSocketPath, + .NoNetwork = Config.NoNetwork, .UnixSocketPath = PathToUtf8(Config.UnixSocketPath), #if ZEN_USE_OPENSSL .HttpsPort = Config.HttpsPort, .CertFile = Config.CertFile, .KeyFile = Config.KeyFile, #endif @@ -1420,72 +1296,6 @@ TEST_CASE("http.common") virtual uint32_t ParseRequestId() const override { return 0; } }; - SUBCASE("router-regex") - { - bool HandledA = false; - bool HandledAA = false; - std::vector<std::string> Captures; - auto Reset = [&] { - Captures.clear(); - HandledA = HandledAA = false; - }; - - TestHttpService Service; - - HttpRequestRouter r; - r.AddPattern("a", "([[:alpha:]]+)"); - r.RegisterRoute( - "{a}", - [&](auto& Req) { - HandledA = true; - Captures = {std::string(Req.GetCapture(0))}; - }, - HttpVerb::kGet); - - r.RegisterRoute( - "{a}/{a}", - [&](auto& Req) { - HandledAA = true; - Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; - }, - HttpVerb::kGet); - - { - Reset(); - TestHttpServerRequest req(Service, "abc"sv); - r.HandleRequest(req); - CHECK(HandledA); - CHECK(!HandledAA); - REQUIRE_EQ(Captures.size(), 1); - CHECK_EQ(Captures[0], "abc"sv); - } - - { - Reset(); - TestHttpServerRequest req{Service, "abc/def"sv}; - r.HandleRequest(req); - CHECK(!HandledA); - CHECK(HandledAA); - REQUIRE_EQ(Captures.size(), 2); - CHECK_EQ(Captures[0], "abc"sv); - CHECK_EQ(Captures[1], "def"sv); - } - - { - Reset(); - TestHttpServerRequest req{Service, "123"sv}; - r.HandleRequest(req); - CHECK(!HandledA); - } - - { - Reset(); - TestHttpServerRequest req{Service, "a123"sv}; - r.HandleRequest(req); - CHECK(!HandledA); - } - } - SUBCASE("router-matcher") { bool HandledA = false; diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h index c252a5d99..3cfe652c5 100644 --- a/src/zenhttp/include/zenhttp/cprutils.h +++ b/src/zenhttp/include/zenhttp/cprutils.h @@ -2,17 +2,19 @@ #pragma once -#include <zencore/compactbinary.h> -#include <zencore/compactbinaryvalidation.h> -#include <zencore/iobuffer.h> -#include <zencore/string.h> -#include <zenhttp/formatters.h> -#include <zenhttp/httpclient.h> -#include <zenhttp/httpcommon.h> +#if ZEN_WITH_CPR + +# include <zencore/compactbinary.h> +# include <zencore/compactbinaryvalidation.h> +# include <zencore/iobuffer.h> +# include <zencore/string.h> +# include <zenhttp/formatters.h> +# include <zenhttp/httpclient.h> +# include <zenhttp/httpcommon.h> ZEN_THIRD_PARTY_INCLUDES_START -#include <cpr/response.h> -#include <fmt/format.h> +# include <cpr/response.h> +# include <fmt/format.h> ZEN_THIRD_PARTY_INCLUDES_END template<> @@ -92,3 +94,5 @@ struct fmt::formatter<cpr::Response> } } }; + +#endif // ZEN_WITH_CPR diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 03c98af7e..e878c900f 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -10,6 +10,7 @@ #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <filesystem> #include <functional> #include <optional> #include <unordered_map> @@ -51,7 +52,9 @@ enum class HttpClientErrorCode : int enum class HttpClientBackend : uint8_t { kDefault, +#if ZEN_WITH_CPR kCpr, +#endif kCurl, }; @@ -91,7 +94,7 @@ struct HttpClientSettings /// Unix domain socket path. When non-empty, the client connects via this /// socket instead of TCP. BaseUri is still used for the Host header and URL. - std::string UnixSocketPath; + std::filesystem::path UnixSocketPath; /// Disable HTTP keep-alive by closing the connection after each request. /// Useful for testing per-connection overhead. @@ -174,11 +177,14 @@ class HttpClientBase; class HttpClient { public: - HttpClient(std::string_view BaseUri, - const HttpClientSettings& Connectionsettings = {}, - std::function<bool()>&& CheckIfAbortFunction = {}); + explicit HttpClient(std::string_view BaseUri, + const HttpClientSettings& Connectionsettings = {}, + std::function<bool()>&& CheckIfAbortFunction = {}); ~HttpClient(); + HttpClient(const HttpClient&) = delete; + HttpClient& operator=(const HttpClient&) = delete; + struct ErrorContext { HttpClientErrorCode ErrorCode; diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 627e7921f..2a8b2ca94 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -15,11 +15,11 @@ #include <zentelemetry/stats.h> +#include <filesystem> #include <functional> #include <gsl/gsl-lite.hpp> #include <list> #include <map> -#include <regex> #include <span> #include <unordered_map> @@ -329,7 +329,7 @@ struct HttpServerConfig std::vector<HttpServerPluginConfig> PluginConfigs; bool ForceLoopback = false; unsigned int ThreadCount = 0; - std::string UnixSocketPath; // Unix domain socket path (empty = disabled, non-Windows only) + std::filesystem::path UnixSocketPath; // Unix domain socket path (empty = disabled) bool NoNetwork = false; // Disable TCP/HTTPS listeners; only accept connections via UnixSocketPath int HttpsPort = 0; // HTTPS listen port (0 = disabled, ASIO backend) std::string CertFile; // PEM certificate chain file path @@ -356,9 +356,8 @@ class HttpRouterRequest public: /** Get captured segment from matched URL * - * @param Index Index of captured segment to retrieve. Note that due to - * backwards compatibility with regex-based routes, this index is 1-based - * and index=0 is the full matched URL + * @param Index Index of captured segment to retrieve. Index 0 is the full + * matched URL, subsequent indices are the matched segments in order. * @return Returns string view of captured segment */ std::string_view GetCapture(uint32_t Index) const; @@ -371,11 +370,8 @@ private: HttpRouterRequest(const HttpRouterRequest&) = delete; HttpRouterRequest& operator=(const HttpRouterRequest&) = delete; - using MatchResults_t = std::match_results<std::string_view::const_iterator>; - HttpServerRequest& m_HttpRequest; - MatchResults_t m_Match; - std::vector<std::string_view> m_CapturedSegments; // for matcher-based routes + std::vector<std::string_view> m_CapturedSegments; friend class HttpRequestRouter; }; @@ -383,9 +379,7 @@ private: /** HTTP request router helper * * This helper class allows a service implementer to register one or more - * endpoints using pattern matching. We currently support a legacy regex-based - * matching system, but also a new matcher-function based system which is more - * efficient and should be used whenever possible. + * endpoints using pattern matching with matcher functions. * * This is intended to be initialized once only, there is no thread * safety so you can absolutely not add or remove endpoints once the handler @@ -404,13 +398,6 @@ public: typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t; /** - * @brief Add pattern which can be referenced by name, commonly used for URL components - * @param Id String used to identify patterns for replacement - * @param Regex String which will replace the Id string in any registered URL paths - */ - void AddPattern(const char* Id, const char* Regex); - - /** * @brief Add matcher function which can be referenced by name, used for URL components * @param Id String used to identify matchers in endpoint specifications * @param Matcher Function which will be called to match the component @@ -420,8 +407,8 @@ public: /** * @brief Register an endpoint handler for the given route * @param Pattern Pattern used to match the handler to a request. This should - * only contain literal URI segments and pattern aliases registered - via AddPattern() or AddMatcher() + * only contain literal URI segments and matcher aliases registered + via AddMatcher() * @param HandlerFunc Handler function to call for any matching request * @param SupportedVerbs Supported HTTP verbs for this handler */ @@ -436,36 +423,6 @@ public: bool HandleRequest(zen::HttpServerRequest& Request); private: - bool ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex); - - struct RegexEndpoint - { - RegexEndpoint(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) - : RegEx(Regex, std::regex::icase | std::regex::ECMAScript) - , Verbs(SupportedVerbs) - , Handler(std::move(Handler)) - , Pattern(Pattern) - { - } - - ~RegexEndpoint() = default; - - std::regex RegEx; - HttpVerb Verbs; - HandlerFunc_t Handler; - const char* Pattern; - - private: - RegexEndpoint& operator=(const RegexEndpoint&) = delete; - RegexEndpoint(const RegexEndpoint&) = delete; - }; - - std::list<RegexEndpoint> m_RegexHandlers; - std::unordered_map<std::string, std::string> m_PatternMap; - - // New-style matcher endpoints. Should be preferred over regex endpoints where possible - // as it is considerably more efficient - struct MatcherEndpoint { MatcherEndpoint(std::vector<int>&& ComponentIndices, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h index 34d338b1d..2ca9b7ab1 100644 --- a/src/zenhttp/include/zenhttp/httpwsclient.h +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -46,7 +46,7 @@ struct HttpWsClientSettings /// Unix domain socket path. When non-empty, connects via this socket /// instead of TCP. The URL host is still used for the Host header. - std::string UnixSocketPath; + std::filesystem::path UnixSocketPath; }; /** diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 643f33618..9f4875eaf 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -601,6 +601,7 @@ public: bool m_IsLocalMachineRequest; bool m_AllowZeroCopyFileSend = true; std::string m_RemoteAddress; + std::string m_DecodedUri; // Percent-decoded URI; m_Uri/m_UriWithExtension point into this std::unique_ptr<HttpResponse> m_Response; }; @@ -623,6 +624,7 @@ public: ~HttpResponse() = default; void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; } + void SetKeepAlive(bool KeepAlive) { m_IsKeepAlive = KeepAlive; } /** * Initialize the response for sending a payload made up of multiple blobs @@ -780,8 +782,8 @@ public: return m_Headers; } - template<typename SocketType> - void SendResponse(SocketType& Socket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token) + template<typename SocketType, typename Executor> + void SendResponse(SocketType& Socket, Executor& Strand, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token) { ZEN_ASSERT(m_State == State::kInitialized); @@ -791,11 +793,11 @@ public: m_SendCb = std::move(Token); m_State = State::kSending; - SendNextChunk(Socket); + SendNextChunk(Socket, Strand); } - template<typename SocketType> - void SendNextChunk(SocketType& Socket) + template<typename SocketType, typename Executor> + void SendNextChunk(SocketType& Socket, Executor& Strand) { ZEN_ASSERT(m_State == State::kSending); @@ -812,12 +814,12 @@ public: auto CompletionToken = [Self = this, Token = std::move(m_SendCb), TotalBytes = m_TotalBytesSent] { Token({}, TotalBytes); }; - asio::defer(Socket.get_executor(), std::move(CompletionToken)); + asio::defer(Strand, std::move(CompletionToken)); return; } - auto OnCompletion = [this, &Socket](const asio::error_code& Ec, std::size_t ByteCount) { + auto OnCompletion = asio::bind_executor(Strand, [this, &Socket, &Strand](const asio::error_code& Ec, std::size_t ByteCount) { ZEN_ASSERT(m_State == State::kSending); m_TotalBytesSent += ByteCount; @@ -828,9 +830,9 @@ public: } else { - SendNextChunk(Socket); + SendNextChunk(Socket, Strand); } - }; + }); const IoVec& Io = m_IoVecs[m_IoVecCursor++]; @@ -982,16 +984,14 @@ private: void CloseConnection(); void SendInlineResponse(uint32_t RequestNumber, std::string_view StatusLine, std::string_view Headers = {}, std::string_view Body = {}); - HttpAsioServerImpl& m_Server; - asio::streambuf m_RequestBuffer; - std::atomic<uint32_t> m_RequestCounter{0}; - uint32_t m_ConnectionId = 0; - Ref<IHttpPackageHandler> m_PackageHandler; - - RwLock m_ActiveResponsesLock; + HttpAsioServerImpl& m_Server; + std::unique_ptr<SocketType> m_Socket; + asio::strand<asio::any_io_executor> m_Strand; + asio::streambuf m_RequestBuffer; + uint32_t m_RequestCounter = 0; + uint32_t m_ConnectionId = 0; + Ref<IHttpPackageHandler> m_PackageHandler; std::deque<std::unique_ptr<HttpResponse>> m_ActiveResponses; - - std::unique_ptr<SocketType> m_Socket; }; std::atomic<uint32_t> g_ConnectionIdCounter{0}; @@ -999,8 +999,9 @@ std::atomic<uint32_t> g_ConnectionIdCounter{0}; template<typename SocketType> HttpServerConnectionT<SocketType>::HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket) : m_Server(Server) -, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) , m_Socket(std::move(Socket)) +, m_Strand(asio::make_strand(m_Socket->get_executor())) +, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) { ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); } @@ -1008,8 +1009,6 @@ HttpServerConnectionT<SocketType>::HttpServerConnectionT(HttpAsioServerImpl& Ser template<typename SocketType> HttpServerConnectionT<SocketType>::~HttpServerConnectionT() { - RwLock::ExclusiveLockScope _(m_ActiveResponsesLock); - ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); } @@ -1017,7 +1016,7 @@ template<typename SocketType> void HttpServerConnectionT<SocketType>::HandleNewRequest() { - EnqueueRead(); + asio::dispatch(m_Strand, [Conn = AsSharedPtr()] { Conn->EnqueueRead(); }); } template<typename SocketType> @@ -1058,7 +1057,9 @@ HttpServerConnectionT<SocketType>::EnqueueRead() asio::async_read(*m_Socket.get(), m_RequestBuffer, asio::transfer_at_least(1), - [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); + asio::bind_executor(m_Strand, [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnDataReceived(Ec, ByteCount); + })); } template<typename SocketType> @@ -1091,7 +1092,7 @@ HttpServerConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[ ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}", m_ConnectionId, - m_RequestCounter.load(std::memory_order_relaxed), + m_RequestCounter, zen::GetCurrentThreadId(), NiceBytes(ByteCount)); @@ -1153,25 +1154,23 @@ HttpServerConnectionT<SocketType>::OnResponseDataSent(const asio::error_code& if (ResponseToPop) { - m_ActiveResponsesLock.WithExclusiveLock([&] { - // Once a response is sent we can release any referenced resources - // - // completion callbacks may be issued out-of-order so we need to - // remove the relevant entry from our active response list, it may - // not be the first - - if (auto It = find_if(begin(m_ActiveResponses), - end(m_ActiveResponses), - [ResponseToPop](const auto& Item) { return Item.get() == ResponseToPop; }); - It != end(m_ActiveResponses)) - { - m_ActiveResponses.erase(It); - } - else - { - ZEN_WARN("response not found"); - } - }); + // Once a response is sent we can release any referenced resources + // + // completion callbacks may be issued out-of-order so we need to + // remove the relevant entry from our active response list, it may + // not be the first + + if (auto It = find_if(begin(m_ActiveResponses), + end(m_ActiveResponses), + [ResponseToPop](const auto& Item) { return Item.get() == ResponseToPop; }); + It != end(m_ActiveResponses)) + { + m_ActiveResponses.erase(It); + } + else + { + ZEN_WARN("response not found"); + } } if (!m_RequestData.IsKeepAlive()) @@ -1234,9 +1233,11 @@ HttpServerConnectionT<SocketType>::SendInlineResponse(uint32_t RequestNumber asio::async_write( *m_Socket, Buffer, - [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + asio::bind_executor( + m_Strand, + [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); + })); } template<typename SocketType> @@ -1272,21 +1273,23 @@ HttpServerConnectionT<SocketType>::HandleRequest() asio::async_write( *m_Socket, asio::buffer(ResponseStr->data(), ResponseStr->size()), - [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); - return; - } - - Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); - using WsConnType = WsAsioConnectionT<SocketType>; - Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); - Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); - WsConn->Start(); - }); + asio::bind_executor( + m_Strand, + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); + return; + } + + Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); + using WsConnType = WsAsioConnectionT<SocketType>; + Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + })); m_RequestState = RequestState::kDone; return; @@ -1312,7 +1315,7 @@ HttpServerConnectionT<SocketType>::HandleRequest() m_RequestState = RequestState::kWriting; } - const uint32_t RequestNumber = m_RequestCounter.fetch_add(1); + const uint32_t RequestNumber = m_RequestCounter++; if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) { @@ -1444,31 +1447,34 @@ HttpServerConnectionT<SocketType>::HandleRequest() { ZEN_TRACE_CPU("asio::async_write"); - std::string_view Headers = Response->GetHeaders(); + HttpResponse* ResponseRaw = Response.get(); + m_ActiveResponses.push_back(std::move(Response)); + + std::string_view Headers = ResponseRaw->GetHeaders(); std::vector<asio::const_buffer> AsioBuffers; AsioBuffers.push_back(asio::const_buffer(Headers.data(), Headers.size())); - asio::async_write(*m_Socket.get(), - AsioBuffers, - asio::transfer_all(), - [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + asio::async_write( + *m_Socket.get(), + AsioBuffers, + asio::transfer_all(), + asio::bind_executor( + m_Strand, + [Conn = AsSharedPtr(), ResponseRaw, RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ ResponseRaw); + })); } else { ZEN_TRACE_CPU("asio::async_write"); HttpResponse* ResponseRaw = Response.get(); - - m_ActiveResponsesLock.WithExclusiveLock([&] { - // Keep referenced resources alive - m_ActiveResponses.push_back(std::move(Response)); - }); + m_ActiveResponses.push_back(std::move(Response)); ResponseRaw->SendResponse( *m_Socket, + m_Strand, [Conn = AsSharedPtr(), ResponseRaw, RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ ResponseRaw); }); @@ -1982,11 +1988,24 @@ HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, { const int PrefixLength = Service.UriPrefixLength(); - std::string_view Uri = Request.Url(); - Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size()))); - m_Uri = Uri; - m_UriWithExtension = Uri; - m_QueryString = Request.QueryString(); + std::string_view RawUri = Request.Url(); + RawUri.remove_prefix(std::min(PrefixLength, static_cast<int>(RawUri.size()))); + + // Percent-decode the URI path so handlers see the same decoded paths regardless + // of whether the ASIO or http.sys backend is used (http.sys pre-decodes via CookedUrl). + // Skip the allocation when there is nothing to decode (common case). + if (RawUri.find('%') != std::string_view::npos) + { + m_DecodedUri = Decode(RawUri); + m_Uri = m_DecodedUri; + m_UriWithExtension = m_DecodedUri; + } + else + { + m_Uri = RawUri; + m_UriWithExtension = RawUri; + } + m_QueryString = Request.QueryString(); m_Verb = Request.RequestVerb(); m_ContentLength = Request.Body().Size(); @@ -2083,6 +2102,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); + m_Response->SetKeepAlive(m_Request.IsKeepAlive()); std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -2097,6 +2117,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); + m_Response->SetKeepAlive(m_Request.IsKeepAlive()); m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } @@ -2108,6 +2129,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); + m_Response->SetKeepAlive(m_Request.IsKeepAlive()); IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 4bf8c61bb..31b0315d4 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -147,7 +147,7 @@ public: HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; // As this is plugin transport connection used for specialized connections we assume it is not a machine local connection - virtual bool IsLocalMachineRequest() const /* override*/ { return false; } + virtual bool IsLocalMachineRequest() const override { return false; } virtual std::string_view GetAuthorizationHeader() const override; virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 4d6a53696..f8fb1c9be 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -1173,7 +1173,7 @@ HttpSysServer::RegisterHttpUrls(int BasePort) { Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); - if ((Result == ERROR_SHARING_VIOLATION)) + if (Result == ERROR_SHARING_VIOLATION) { ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 9b461662e..b4c65ea96 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -8,7 +8,12 @@ target('zenhttp') add_files("servers/httpsys.cpp", {unity_ignored=true}) add_files("servers/wshttpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) - add_deps("zencore", "zentelemetry", "transport-sdk", "asio", "cpr") + add_deps("zencore", "zentelemetry", "transport-sdk", "asio") + if has_config("zencpr") then + add_deps("cpr") + else + remove_files("clients/httpclientcpr.cpp") + end add_packages("http_parser", "json11") add_options("httpsys") |