diff options
Diffstat (limited to 'src/zenhttp/clients')
| -rw-r--r-- | src/zenhttp/clients/httpclientcommon.cpp | 380 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcommon.h | 120 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.cpp | 791 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.h | 34 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.cpp | 1816 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.h | 137 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpwsclient.cpp | 641 |
7 files changed, 3590 insertions, 329 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp index 47425e014..e4d11547a 100644 --- a/src/zenhttp/clients/httpclientcommon.cpp +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -142,7 +142,10 @@ namespace detail { DataSize -= CopySize; if (m_CacheBufferOffset == CacheBufferSize) { - AppendData(m_CacheBuffer, CacheBufferSize); + if (std::error_code Ec = AppendData(m_CacheBuffer, CacheBufferSize)) + { + return Ec; + } if (DataSize > 0) { ZEN_ASSERT(DataSize < CacheBufferSize); @@ -382,6 +385,177 @@ namespace detail { return Result; } + MultipartBoundaryParser::MultipartBoundaryParser() : BoundaryEndMatcher("--"), HeaderEndMatcher("\r\n\r\n") {} + + bool MultipartBoundaryParser::Init(const std::string_view ContentTypeHeaderValue) + { + std::string LowerCaseValue = ToLower(ContentTypeHeaderValue); + if (LowerCaseValue.starts_with("multipart/byteranges")) + { + size_t BoundaryPos = LowerCaseValue.find("boundary="); + if (BoundaryPos != std::string::npos) + { + // Yes, we do a substring of the non-lowercase value string as we want the exact boundary string + std::string_view BoundaryName = std::string_view(ContentTypeHeaderValue).substr(BoundaryPos + 9); + size_t BoundaryEnd = std::string::npos; + while (!BoundaryName.empty() && BoundaryName[0] == ' ') + { + BoundaryName = BoundaryName.substr(1); + } + if (!BoundaryName.empty()) + { + if (BoundaryName.size() > 2 && BoundaryName.front() == '"' && BoundaryName.back() == '"') + { + BoundaryEnd = BoundaryName.find('"', 1); + if (BoundaryEnd != std::string::npos) + { + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(1, BoundaryEnd - 1))); + return true; + } + } + else + { + BoundaryEnd = BoundaryName.find_first_of(" \r\n"); + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(0, BoundaryEnd))); + return true; + } + } + } + } + return false; + } + + void MultipartBoundaryParser::ParseInput(std::string_view data) + { + const char* InputPtr = data.data(); + size_t InputLength = data.length(); + size_t ScanPos = 0; + while (ScanPos < InputLength) + { + const char ScanChar = InputPtr[ScanPos]; + if (BoundaryBeginMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + if (PayloadOffset + ScanPos < (BoundaryBeginMatcher.GetMatchEndOffset() + BoundaryEndMatcher.GetMatchString().length())) + { + BoundaryEndMatcher.Match(PayloadOffset + ScanPos, ScanChar); + if (BoundaryEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + BoundaryBeginMatcher.Reset(); + HeaderEndMatcher.Reset(); + BoundaryEndMatcher.Reset(); + BoundaryHeader.Reset(); + break; + } + } + + BoundaryHeader.Append(ScanChar); + + HeaderEndMatcher.Match(PayloadOffset + ScanPos, ScanChar); + + if (HeaderEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + const uint64_t HeaderStartOffset = BoundaryBeginMatcher.GetMatchEndOffset(); + const uint64_t HeaderEndOffset = HeaderEndMatcher.GetMatchStartOffset(); + const uint64_t HeaderLength = HeaderEndOffset - HeaderStartOffset; + std::string_view HeaderText(BoundaryHeader.ToView().substr(0, HeaderLength)); + + uint64_t OffsetInPayload = PayloadOffset + ScanPos + 1; + + uint64_t RangeOffset = 0; + uint64_t RangeLength = 0; + HttpContentType ContentType = HttpContentType::kBinary; + + ForEachStrTok(HeaderText, "\r\n", [&](std::string_view Line) { + const std::pair<std::string_view, std::string_view> KeyAndValue = GetHeaderKeyAndValue(Line); + const std::string_view Key = KeyAndValue.first; + const std::string_view Value = KeyAndValue.second; + if (Key == "Content-Range") + { + std::pair<uint64_t, uint64_t> ContentRange = ParseContentRange(Value); + if (ContentRange.second != 0) + { + RangeOffset = ContentRange.first; + RangeLength = ContentRange.second; + } + } + else if (Key == "Content-Type") + { + ContentType = ParseContentType(Value); + } + + return true; + }); + + if (RangeLength > 0) + { + Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = OffsetInPayload, + .RangeOffset = RangeOffset, + .RangeLength = RangeLength, + .ContentType = ContentType}); + } + + BoundaryBeginMatcher.Reset(); + HeaderEndMatcher.Reset(); + BoundaryEndMatcher.Reset(); + BoundaryHeader.Reset(); + } + } + else + { + BoundaryBeginMatcher.Match(PayloadOffset + ScanPos, ScanChar); + } + ScanPos++; + } + PayloadOffset += InputLength; + } + + std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString) + { + size_t DelimiterPos = HeaderString.find(':'); + if (DelimiterPos != std::string::npos) + { + std::string_view Key = HeaderString.substr(0, DelimiterPos); + constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); + Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); + Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); + + std::string_view Value = HeaderString.substr(DelimiterPos + 1); + Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); + Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); + return std::make_pair(Key, Value); + } + return std::make_pair(HeaderString, std::string_view{}); + } + + std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value) + { + if (Value.starts_with("bytes ")) + { + size_t RangeSplitPos = Value.find('-', 6); + if (RangeSplitPos != std::string::npos) + { + size_t RangeEndLength = Value.find('/', RangeSplitPos + 1); + if (RangeEndLength == std::string::npos) + { + RangeEndLength = Value.length() - (RangeSplitPos + 1); + } + else + { + RangeEndLength = RangeEndLength - (RangeSplitPos + 1); + } + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(Value.substr(6, RangeSplitPos - 6)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(Value.substr(RangeSplitPos + 1, RangeEndLength)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + uint64_t RangeOffset = RequestedRangeStart.value(); + uint64_t RangeLength = RequestedRangeEnd.value() - RangeOffset + 1; + return std::make_pair(RangeOffset, RangeLength); + } + } + } + return {0, 0}; + } + } // namespace detail } // namespace zen @@ -423,6 +597,8 @@ namespace testutil { } // namespace testutil +TEST_SUITE_BEGIN("http.httpclientcommon"); + TEST_CASE("BufferedReadFileStream") { ScopedTemporaryDirectory TmpDir; @@ -470,5 +646,207 @@ TEST_CASE("CompositeBufferReadStream") CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data)); } +TEST_CASE("ParseContentRange") +{ + SUBCASE("normal range with total size") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 0-99/500"); + CHECK_EQ(Offset, 0); + CHECK_EQ(Length, 100); + } + + SUBCASE("non-zero offset") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 2638-5111437/44369878"); + CHECK_EQ(Offset, 2638); + CHECK_EQ(Length, 5111437 - 2638 + 1); + } + + SUBCASE("wildcard total size") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 100-199/*"); + CHECK_EQ(Offset, 100); + CHECK_EQ(Length, 100); + } + + SUBCASE("no slash (total size omitted)") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 50-149"); + CHECK_EQ(Offset, 50); + CHECK_EQ(Length, 100); + } + + SUBCASE("malformed input returns zeros") + { + auto [Offset1, Length1] = detail::ParseContentRange("not-bytes 0-99/500"); + CHECK_EQ(Offset1, 0); + CHECK_EQ(Length1, 0); + + auto [Offset2, Length2] = detail::ParseContentRange("bytes abc-def/500"); + CHECK_EQ(Offset2, 0); + CHECK_EQ(Length2, 0); + + auto [Offset3, Length3] = detail::ParseContentRange(""); + CHECK_EQ(Offset3, 0); + CHECK_EQ(Length3, 0); + + auto [Offset4, Length4] = detail::ParseContentRange("bytes 100/500"); + CHECK_EQ(Offset4, 0); + CHECK_EQ(Length4, 0); + } + + SUBCASE("single byte range") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 42-42/1000"); + CHECK_EQ(Offset, 42); + CHECK_EQ(Length, 1); + } +} + +TEST_CASE("MultipartBoundaryParser") +{ + uint64_t Range1Offset = 2638; + uint64_t Range1Length = (5111437 - Range1Offset) + 1; + + uint64_t Range2Offset = 5118199; + uint64_t Range2Length = (9147741 - Range2Offset) + 1; + + std::string_view ContentTypeHeaderValue1 = "multipart/byteranges; boundary=00000000000000019229"; + std::string_view ContentTypeHeaderValue2 = "multipart/byteranges; boundary=\"00000000000000019229\""; + + { + std::string_view Example1 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/44369878\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample1; + ParserExample1.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 7; + for (size_t Offset = 0; Offset < Example1.length(); Offset += InputWindow) + { + ParserExample1.ParseInput(Example1.substr(Offset, Min(Example1.length() - Offset, InputWindow))); + } + + CHECK(ParserExample1.Boundaries.size() == 2); + + CHECK(ParserExample1.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample1.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample1.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample1.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example2 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample2; + ParserExample2.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 3; + for (size_t Offset = 0; Offset < Example2.length(); Offset += InputWindow) + { + std::string_view Window = Example2.substr(Offset, Min(Example2.length() - Offset, InputWindow)); + ParserExample2.ParseInput(Window); + } + + CHECK(ParserExample2.Boundaries.size() == 2); + + CHECK(ParserExample2.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample2.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample2.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample2.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example3 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita"; + + detail::MultipartBoundaryParser ParserExample3; + ParserExample3.Init(ContentTypeHeaderValue2); + + const size_t InputWindow = 31; + for (size_t Offset = 0; Offset < Example3.length(); Offset += InputWindow) + { + ParserExample3.ParseInput(Example3.substr(Offset, Min(Example3.length() - Offset, InputWindow))); + } + + CHECK(ParserExample3.Boundaries.size() == 2); + + CHECK(ParserExample3.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample3.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample3.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample3.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example4 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "Not: really\r\n" + "\r\n" + "datadatadatadata" + "\r\n--000000000bait0019229\r\n" + "\r\n--00\r\n--000000000bait001922\r\n" + "\r\n\r\n\r\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "Content-Type: application/x-ue-comp\r\n" + "ditaditadita" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n---\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample4; + ParserExample4.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 3; + for (size_t Offset = 0; Offset < Example4.length(); Offset += InputWindow) + { + std::string_view Window = Example4.substr(Offset, Min(Example4.length() - Offset, InputWindow)); + ParserExample4.ParseInput(Window); + } + + CHECK(ParserExample4.Boundaries.size() == 2); + + CHECK(ParserExample4.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample4.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample4.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample4.Boundaries[1].RangeLength == Range2Length); + } +} + +TEST_SUITE_END(); + } // namespace zen #endif diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h index 1d0b7f9ea..e95e3a253 100644 --- a/src/zenhttp/clients/httpclientcommon.h +++ b/src/zenhttp/clients/httpclientcommon.h @@ -3,6 +3,7 @@ #pragma once #include <zencore/compositebuffer.h> +#include <zencore/string.h> #include <zencore/trace.h> #include <zenhttp/httpclient.h> @@ -35,7 +36,10 @@ public: const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader = {}) = 0; - [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader = {}, + const std::filesystem::path& TempFolderPath = {}) = 0; [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) = 0; [[nodiscard]] virtual Response Post(std::string_view Url, const CompositeBuffer& Payload, @@ -87,7 +91,7 @@ namespace detail { std::error_code Write(std::string_view DataString); IoBuffer DetachToIoBuffer(); IoBuffer BorrowIoBuffer(); - inline uint64_t GetSize() const { return m_WriteOffset; } + inline uint64_t GetSize() const { return m_WriteOffset + m_CacheBufferOffset; } void ResetWritePos(uint64_t WriteOffset); private: @@ -143,6 +147,118 @@ namespace detail { uint64_t m_BytesLeftInSegment; }; + class IncrementalStringMatcher + { + public: + enum class EMatchState + { + None, + Partial, + Complete + }; + + EMatchState MatchState = EMatchState::None; + + IncrementalStringMatcher() {} + + IncrementalStringMatcher(std::string&& InMatchString) : MatchString(std::move(InMatchString)) + { + RawMatchString = MatchString.data(); + } + + void Init(std::string&& InMatchString) + { + MatchString = std::move(InMatchString); + RawMatchString = MatchString.data(); + } + + inline void Reset() + { + MatchLength = 0; + MatchStartOffset = 0; + MatchState = EMatchState::None; + } + + inline uint64_t GetMatchEndOffset() const + { + if (MatchState == EMatchState::Complete) + { + return MatchStartOffset + MatchString.length(); + } + return 0; + } + + inline uint64_t GetMatchStartOffset() const + { + ZEN_ASSERT(MatchState == EMatchState::Complete); + return MatchStartOffset; + } + + void Match(uint64_t Offset, char C) + { + ZEN_ASSERT_SLOW(RawMatchString != nullptr); + + if (MatchState == EMatchState::Complete) + { + Reset(); + } + if (C == RawMatchString[MatchLength]) + { + if (MatchLength == 0) + { + MatchStartOffset = Offset; + } + MatchLength++; + if (MatchLength == MatchString.length()) + { + MatchState = EMatchState::Complete; + } + else + { + MatchState = EMatchState::Partial; + } + } + else if (MatchLength != 0) + { + Reset(); + Match(Offset, C); + } + else + { + Reset(); + } + } + inline const std::string& GetMatchString() const { return MatchString; } + + private: + std::string MatchString; + const char* RawMatchString = nullptr; + uint64_t MatchLength = 0; + + uint64_t MatchStartOffset = 0; + }; + + class MultipartBoundaryParser + { + public: + std::vector<HttpClient::Response::MultipartBoundary> Boundaries; + + MultipartBoundaryParser(); + bool Init(const std::string_view ContentTypeHeaderValue); + void ParseInput(std::string_view data); + + private: + IncrementalStringMatcher BoundaryBeginMatcher; + IncrementalStringMatcher BoundaryEndMatcher; + IncrementalStringMatcher HeaderEndMatcher; + + ExtendableStringBuilder<64> BoundaryHeader; + uint64_t PayloadOffset = 0; + }; + + std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString); + std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value); + } // namespace detail } // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index 5d92b3b6b..a52b8f74b 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -7,11 +7,18 @@ #include <zencore/compactbinarypackage.h> #include <zencore/compactbinaryutil.h> #include <zencore/compress.h> +#include <zencore/filesystem.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> #include <zencore/session.h> #include <zencore/stream.h> #include <zenhttp/packageformat.h> +#include <algorithm> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/ssl_options.h> +#include <cpr/unix_socket.h> +ZEN_THIRD_PARTY_INCLUDES_END namespace zen { @@ -23,69 +30,42 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; -// If we want to support different HTTP client implementations then we'll need to make this more abstract +////////////////////////////////////////////////////////////////////////// -HttpClientError::ResponseClass -HttpClientError::GetResponseClass() const +static HttpClientErrorCode +MapCprError(cpr::ErrorCode Code) { - if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK) - { - switch ((cpr::ErrorCode)m_Error) - { - case cpr::ErrorCode::CONNECTION_FAILURE: - return ResponseClass::kHttpCantConnectError; - case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: - case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: - return ResponseClass::kHttpNoHost; - case cpr::ErrorCode::INTERNAL_ERROR: - case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: - case cpr::ErrorCode::NETWORK_SEND_FAILURE: - case cpr::ErrorCode::OPERATION_TIMEDOUT: - return ResponseClass::kHttpTimeout; - case cpr::ErrorCode::SSL_CONNECT_ERROR: - case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: - case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: - case cpr::ErrorCode::SSL_CACERT_ERROR: - case cpr::ErrorCode::GENERIC_SSL_ERROR: - return ResponseClass::kHttpSLLError; - default: - return ResponseClass::kHttpOtherClientError; - } - } - else if (IsHttpSuccessCode(m_ResponseCode)) + switch (Code) { - return ResponseClass::kSuccess; - } - else - { - switch (m_ResponseCode) - { - case HttpResponseCode::Unauthorized: - return ResponseClass::kHttpUnauthorized; - case HttpResponseCode::NotFound: - return ResponseClass::kHttpNotFound; - case HttpResponseCode::Forbidden: - return ResponseClass::kHttpForbidden; - case HttpResponseCode::Conflict: - return ResponseClass::kHttpConflict; - case HttpResponseCode::InternalServerError: - return ResponseClass::kHttpInternalServerError; - case HttpResponseCode::ServiceUnavailable: - return ResponseClass::kHttpServiceUnavailable; - case HttpResponseCode::BadGateway: - return ResponseClass::kHttpBadGateway; - case HttpResponseCode::GatewayTimeout: - return ResponseClass::kHttpGatewayTimeout; - default: - if (m_ResponseCode >= HttpResponseCode::InternalServerError) - { - return ResponseClass::kHttpOtherServerError; - } - else - { - return ResponseClass::kHttpOtherClientError; - } - } + case cpr::ErrorCode::OK: + return HttpClientErrorCode::kOK; + case cpr::ErrorCode::CONNECTION_FAILURE: + return HttpClientErrorCode::kConnectionFailure; + case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: + return HttpClientErrorCode::kHostResolutionFailure; + case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: + return HttpClientErrorCode::kProxyResolutionFailure; + case cpr::ErrorCode::INTERNAL_ERROR: + return HttpClientErrorCode::kInternalError; + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + return HttpClientErrorCode::kNetworkSendFailure; + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case cpr::ErrorCode::SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: + return HttpClientErrorCode::kSSLCertificateError; + case cpr::ErrorCode::SSL_CACERT_ERROR: + return HttpClientErrorCode::kSSLCACertError; + case cpr::ErrorCode::GENERIC_SSL_ERROR: + return HttpClientErrorCode::kGenericSSLError; + case cpr::ErrorCode::REQUEST_CANCELLED: + return HttpClientErrorCode::kRequestCancelled; + default: + return HttpClientErrorCode::kOtherError; } } @@ -149,6 +129,18 @@ CprHttpClient::CprHttpClient(std::string_view BaseUri, { } +bool +CprHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const +{ + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + // Quiet + return false; + } + const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes; + return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end(); +} + CprHttpClient::~CprHttpClient() { ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient"); @@ -162,10 +154,11 @@ CprHttpClient::~CprHttpClient() } HttpClient::Response -CprHttpClient::ResponseWithPayload(std::string_view SessionId, - cpr::Response&& HttpResponse, - const HttpResponseCode WorkResponseCode, - IoBuffer&& Payload) +CprHttpClient::ResponseWithPayload(std::string_view SessionId, + cpr::Response&& HttpResponse, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) { // This ends up doing a memcpy, would be good to get rid of it by streaming results // into buffer directly @@ -174,30 +167,37 @@ CprHttpClient::ResponseWithPayload(std::string_view SessionId, if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end()) { const HttpContentType ContentType = ParseContentType(It->second); - ResponseBuffer.SetContentType(ContentType); } - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - - if (!Quiet) + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) { - if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + if (ShouldLogErrorCode(WorkResponseCode)) { ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse); } } + std::sort(BoundaryPositions.begin(), + BoundaryPositions.end(), + [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) { + return Lhs.RangeOffset < Rhs.RangeOffset; + }); + return HttpClient::Response{.StatusCode = WorkResponseCode, .ResponsePayload = std::move(ResponseBuffer), .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), - .ElapsedSeconds = HttpResponse.elapsed}; + .ElapsedSeconds = HttpResponse.elapsed, + .Ranges = std::move(BoundaryPositions)}; } HttpClient::Response -CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload) +CprHttpClient::CommonResponse(std::string_view SessionId, + cpr::Response&& HttpResponse, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) { const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code); if (HttpResponse.error) @@ -221,8 +221,8 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), .ElapsedSeconds = HttpResponse.elapsed, - .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code), - .ErrorMessage = HttpResponse.error.message}}; + .Error = + HttpClient::ErrorContext{.ErrorCode = MapCprError(HttpResponse.error.code), .ErrorMessage = HttpResponse.error.message}}; } if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload)) @@ -235,7 +235,7 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe } else { - return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload)); + return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions)); } } @@ -346,8 +346,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, } Sleep(100 * (Attempt + 1)); Attempt++; - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - if (!Quiet) + if (ShouldLogErrorCode(HttpResponseCode(Result.status_code))) { ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), @@ -385,8 +384,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, } Sleep(100 * (Attempt + 1)); Attempt++; - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - if (!Quiet) + if (ShouldLogErrorCode(HttpResponseCode(Result.status_code))) { ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), @@ -492,6 +490,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, { CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}}); } + if (ConnectionSettings.ForbidReuseConnection) + { + CprSession->UpdateHeader({{"Connection", "close"}}); + } if (AccessToken) { CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); @@ -510,6 +512,26 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, CprSession->SetParameters({}); } + if (!ConnectionSettings.UnixSocketPath.empty()) + { + CprSession->SetUnixSocket(cpr::UnixSocket(PathToUtf8(ConnectionSettings.UnixSocketPath))); + } + + if (ConnectionSettings.InsecureSsl || !ConnectionSettings.CaBundlePath.empty()) + { + cpr::SslOptions SslOpts; + if (ConnectionSettings.InsecureSsl) + { + SslOpts.SetOption(cpr::ssl::VerifyHost{false}); + SslOpts.SetOption(cpr::ssl::VerifyPeer{false}); + } + if (!ConnectionSettings.CaBundlePath.empty()) + { + SslOpts.SetOption(cpr::ssl::CaInfo{ConnectionSettings.CaBundlePath}); + } + CprSession->SetSslOptions(SslOpts); + } + ExtendableStringBuilder<128> UrlBuffer; UrlBuffer << BaseUrl << ResourcePath; CprSession->SetUrl(UrlBuffer.c_str()); @@ -621,7 +643,7 @@ CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const Ke ResponseBuffer.SetContentType(ContentType); } - return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; + return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = std::move(ResponseBuffer)}; } ////////////////////////////////////////////////////////////////////////// @@ -774,22 +796,97 @@ CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentTyp } CprHttpClient::Response -CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +CprHttpClient::Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader, + const std::filesystem::path& TempFolderPath) { ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload"); - return CommonResponse( + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + + cpr::Response Response = DoWithRetry( m_SessionId, - DoWithRetry(m_SessionId, - [&]() { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + [&]() { + PayloadString.clear(); + PayloadFile.reset(); - Sess->SetBody(AsCprBody(Payload)); - Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); - return Sess.Post(); - }), - {}); + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + + Sess->SetBody(AsCprBody(Payload)); + Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); + + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); + if (StrCaseCompare(std::string(Header.first).c_str(), "Content-Length") == 0) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); + if (ContentLength.has_value()) + { + if (!TempFolderPath.empty() && ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) + { + PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + auto DownloadCallback = [&](std::string data, intptr_t) { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + cpr::Response Response = Sess.Post({}, cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile); + return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); } CprHttpClient::Response @@ -896,236 +993,292 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF std::string PayloadString; std::unique_ptr<detail::TempPayloadFile> PayloadFile; - cpr::Response Response = DoWithRetry( - m_SessionId, - [&]() { - auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> { - size_t DelimiterPos = header.find(':'); - if (DelimiterPos != std::string::npos) - { - std::string Key = header.substr(0, DelimiterPos); - constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); - Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); - Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); - - std::string Value = header.substr(DelimiterPos + 1); - Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); - Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); - - return std::make_pair(Key, Value); - } - return std::make_pair(header, ""); - }; - - auto DownloadCallback = [&](std::string data, intptr_t) { - if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) - { - return false; - } - if (PayloadFile) - { - ZEN_ASSERT(PayloadString.empty()); - std::error_code Ec = PayloadFile->Write(data); - if (Ec) - { - ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", - TempFolderPath.string(), - Ec.message()); - return false; - } - } - else - { - PayloadString.append(data); - } - return true; - }; - - uint64_t RequestedContentLength = (uint64_t)-1; - if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) - { - if (RangeIt->second.starts_with("bytes")) - { - size_t RangeStartPos = RangeIt->second.find('=', 5); - if (RangeStartPos != std::string::npos) - { - RangeStartPos++; - size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos); - if (RangeSplitPos != std::string::npos) - { - std::optional<size_t> RequestedRangeStart = - ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos)); - std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1)); - if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) - { - RequestedContentLength = RequestedRangeEnd.value() - 1; - } - } - } - } - } - - cpr::Response Response; - { - std::vector<std::pair<std::string, std::string>> ReceivedHeaders; - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair<std::string, std::string> Header = GetHeader(header); - if (Header.first == "Content-Length"sv) - { - std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); - if (ContentLength.has_value()) - { - if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) - { - PayloadFile = std::make_unique<detail::TempPayloadFile>(); - std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); - if (Ec) - { - ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", - TempFolderPath.string(), - Ec.message()); - PayloadFile.reset(); - } - } - else - { - PayloadString.reserve(ContentLength.value()); - } - } - } - if (!Header.first.empty()) - { - ReceivedHeaders.emplace_back(std::move(Header)); - } - return 1; - }; - - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair<std::string, std::string>& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - } - if (m_ConnectionSettings.AllowResume) - { - auto SupportsRanges = [](const cpr::Response& Response) -> bool { - if (Response.header.find("Content-Range") != Response.header.end()) - { - return true; - } - if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) - { - return It->second == "bytes"sv; - } - return false; - }; - - auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool { - if (ShouldRetry(Response)) - { - return SupportsRanges(Response); - } - return false; - }; - - if (ShouldResume(Response)) - { - auto It = Response.header.find("Content-Length"); - if (It != Response.header.end()) - { - uint64_t ContentLength = RequestedContentLength; - if (ContentLength == uint64_t(-1)) - { - if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) - { - ContentLength = ParsedContentLength.value(); - } - } - - std::vector<std::pair<std::string, std::string>> ReceivedHeaders; - - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair<std::string, std::string> Header = GetHeader(header); - if (!Header.first.empty()) - { - ReceivedHeaders.emplace_back(std::move(Header)); - } - - if (Header.first == "Content-Range"sv) - { - if (Header.second.starts_with("bytes "sv)) - { - size_t RangeStartEnd = Header.second.find('-', 6); - if (RangeStartEnd != std::string::npos) - { - const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); - if (Start) - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - if (Start.value() == DownloadedSize) - { - return 1; - } - else if (Start.value() > DownloadedSize) - { - return 0; - } - if (PayloadFile) - { - PayloadFile->ResetWritePos(Start.value()); - } - else - { - PayloadString = PayloadString.substr(0, Start.value()); - } - return 1; - } - } - } - return 0; - } - return 1; - }; - - KeyValueMap HeadersWithRange(AdditionalHeader); - do - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - - std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); - if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) - { - if (RangeIt->second == Range) - { - // If we didn't make any progress, abort - break; - } - } - HeadersWithRange.Entries.insert_or_assign("Range", Range); - - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair<std::string, std::string>& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - ReceivedHeaders.clear(); - } while (ShouldResume(Response)); - } - } - } - - if (!PayloadString.empty()) - { - Response.text = std::move(PayloadString); - } - return Response; - }, - PayloadFile); - return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); + HttpContentType ContentType = HttpContentType::kUnknownContentType; + detail::MultipartBoundaryParser BoundaryParser; + bool IsMultiRangeResponse = false; + + cpr::Response Response = DoWithRetry( + m_SessionId, + [&]() { + // Reset state from any previous attempt + PayloadString.clear(); + PayloadFile.reset(); + BoundaryParser.Boundaries.clear(); + ContentType = HttpContentType::kUnknownContentType; + IsMultiRangeResponse = false; + + auto DownloadCallback = [&](std::string data, intptr_t) { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + + if (IsMultiRangeResponse) + { + BoundaryParser.ParseInput(data); + } + + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + std::string_view RangeValue(RangeIt->second); + size_t RangeStartPos = RangeValue.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') + { + RangeStartPos++; + } + RequestedContentLength = 0; + + while (RangeStartPos < RangeValue.length()) + { + size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); + if (RangeEnd == std::string::npos) + { + RangeEnd = RangeValue.length(); + } + + std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); + size_t RangeSplitPos = RangeString.find('-'); + if (RangeSplitPos != std::string::npos) + { + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1; + } + } + RangeStartPos = RangeEnd; + while (RangeStartPos != RangeValue.length() && + (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' ')) + { + RangeStartPos++; + } + } + } + } + } + + cpr::Response Response; + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + if (RequestedContentLength != (uint64_t)-1 && RequestedContentLength > m_ConnectionSettings.MaximumInMemoryDownloadSize) + { + ZEN_DEBUG("Multirange request"); + } + const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); + const std::string Key(Header.first); + if (StrCaseCompare(Key.c_str(), "Content-Length") == 0) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); + if (ContentLength.has_value()) + { + if (!TempFolderPath.empty() && ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) + { + PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + else if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) + { + IsMultiRangeResponse = BoundaryParser.Init(Header.second); + if (!IsMultiRangeResponse) + { + ContentType = ParseContentType(Header.second); + } + } + else if (StrCaseCompare(Key.c_str(), "Content-Range") == 0) + { + if (!IsMultiRangeResponse) + { + std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Header.second); + if (Range.second != 0) + { + BoundaryParser.Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0, + .RangeOffset = Range.first, + .RangeLength = Range.second, + .ContentType = ContentType}); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + } + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const cpr::Response& Response) -> bool { + if (Response.header.find("Content-Range") != Response.header.end()) + { + return true; + } + if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) + { + return It->second == "bytes"sv; + } + return false; + }; + + auto ShouldResume = [&SupportsRanges, &IsMultiRangeResponse](const cpr::Response& Response) -> bool { + if (IsMultiRangeResponse) + { + return false; + } + if (ShouldRetry(Response)) + { + return SupportsRanges(Response); + } + return false; + }; + + if (ShouldResume(Response)) + { + auto It = Response.header.find("Content-Length"); + if (It != Response.header.end()) + { + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + + auto HeaderCallback = [&](std::string header, intptr_t) { + const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + + if (StrCaseCompare(std::string(Header.first).c_str(), "Content-Range") == 0) + { + if (Header.second.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Header.second.find('-', 6); + if (RangeStartEnd != std::string::npos) + { + const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + if (Start.value() == DownloadedSize) + { + return 1; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (PayloadFile) + { + PayloadFile->ResetWritePos(Start.value()); + } + else + { + PayloadString = PayloadString.substr(0, Start.value()); + } + return 1; + } + } + } + return 0; + } + return 1; + }; + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + // If we didn't make any progress, abort + break; + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + ReceivedHeaders.clear(); + } while (ShouldResume(Response)); + } + } + } + + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile); + + return CommonResponse(m_SessionId, + std::move(Response), + PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, + std::move(BoundaryParser.Boundaries)); } } // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h index 40af53b5d..009e6fb7a 100644 --- a/src/zenhttp/clients/httpclientcpr.h +++ b/src/zenhttp/clients/httpclientcpr.h @@ -38,7 +38,10 @@ public: const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader = {}) override; - [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader = {}, + const std::filesystem::path& TempFolderPath = {}) override; [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override; [[nodiscard]] virtual Response Post(std::string_view Url, const CompositeBuffer& Payload, @@ -104,15 +107,27 @@ private: CprSession->SetReadCallback({}); return Result; } - inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {}) + inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {}, + std::optional<cpr::WriteCallback>&& Write = {}, + std::optional<cpr::HeaderCallback>&& Header = {}) { ZEN_TRACE_CPU("HttpClient::Impl::Post"); if (Read) { CprSession->SetReadCallback(std::move(Read.value())); } + if (Write) + { + CprSession->SetWriteCallback(std::move(Write.value())); + } + if (Header) + { + CprSession->SetHeaderCallback(std::move(Header.value())); + } cpr::Response Result = CprSession->Post(); ZEN_TRACE("POST {}", Result); + CprSession->SetHeaderCallback({}); + CprSession->SetWriteCallback({}); CprSession->SetReadCallback({}); return Result; } @@ -155,14 +170,19 @@ private: std::function<cpr::Response()>&& Func, std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; }); + bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const; bool ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile); - HttpClient::Response CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload); + HttpClient::Response CommonResponse(std::string_view SessionId, + cpr::Response&& HttpResponse, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {}); - HttpClient::Response ResponseWithPayload(std::string_view SessionId, - cpr::Response&& HttpResponse, - const HttpResponseCode WorkResponseCode, - IoBuffer&& Payload); + HttpClient::Response ResponseWithPayload(std::string_view SessionId, + cpr::Response&& HttpResponse, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions); }; } // namespace zen diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp new file mode 100644 index 000000000..ec9b7bac6 --- /dev/null +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -0,0 +1,1816 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcurl.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/compress.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zenhttp/packageformat.h> +#include <algorithm> + +namespace zen { + +HttpClientBase* +CreateCurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction) +{ + return new CurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); +} + +static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0}; + +////////////////////////////////////////////////////////////////////////// + +static HttpClientErrorCode +MapCurlError(CURLcode Code) +{ + switch (Code) + { + case CURLE_OK: + return HttpClientErrorCode::kOK; + case CURLE_COULDNT_CONNECT: + return HttpClientErrorCode::kConnectionFailure; + case CURLE_COULDNT_RESOLVE_HOST: + return HttpClientErrorCode::kHostResolutionFailure; + case CURLE_COULDNT_RESOLVE_PROXY: + return HttpClientErrorCode::kProxyResolutionFailure; + case CURLE_RECV_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case CURLE_SEND_ERROR: + return HttpClientErrorCode::kNetworkSendFailure; + case CURLE_OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case CURLE_SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case CURLE_SSL_CERTPROBLEM: + return HttpClientErrorCode::kSSLCertificateError; + case CURLE_PEER_FAILED_VERIFICATION: + return HttpClientErrorCode::kSSLCACertError; + case CURLE_SSL_CIPHER: + case CURLE_SSL_ENGINE_NOTFOUND: + case CURLE_SSL_ENGINE_SETFAILED: + return HttpClientErrorCode::kGenericSSLError; + case CURLE_ABORTED_BY_CALLBACK: + return HttpClientErrorCode::kRequestCancelled; + default: + return HttpClientErrorCode::kOtherError; + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Curl callback helpers + +struct WriteCallbackData +{ + std::string* Body = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<WriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; // Signal abort to curl + } + + Data->Body->append(Ptr, TotalBytes); + return TotalBytes; +} + +struct HeaderCallbackData +{ + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; +}; + +// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value. +// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines). +static std::optional<std::pair<std::string_view, std::string_view>> +ParseHeaderLine(std::string_view Line) +{ + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return std::nullopt; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos == std::string_view::npos) + { + return std::nullopt; + } + + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + return std::pair{Key, Value}; +} + +static size_t +CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<HeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) + { + auto& [Key, Value] = *Header; + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; +} + +struct ReadCallbackData +{ + const uint8_t* DataPtr = nullptr; + size_t DataSize = 0; + size_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<ReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->DataSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +struct StreamReadCallbackData +{ + detail::CompositeBufferReadStream* Reader = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlStreamReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<StreamReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + return Data->Reader->Read(Buffer, MaxRead); +} + +struct FileReadCallbackData +{ + detail::BufferedReadFileStream* Buffer = nullptr; + uint64_t TotalSize = 0; + uint64_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlFileReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<FileReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->TotalSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + Data->Buffer->Read(Buffer, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +static int +CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, void* UserPtr) +{ + ZEN_UNUSED(Handle); + LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr); + auto Log = [&]() -> LoggerRef { return LogRef; }; + + std::string_view DataView(Data, Size); + + // Trim trailing newlines + while (!DataView.empty() && (DataView.back() == '\r' || DataView.back() == '\n')) + { + DataView.remove_suffix(1); + } + + switch (Type) + { + case CURLINFO_TEXT: + if (DataView.find("need more data"sv) == std::string_view::npos) + { + ZEN_INFO("TEXT: {}", DataView); + } + break; + case CURLINFO_HEADER_IN: + ZEN_INFO("HIN : {}", DataView); + break; + case CURLINFO_HEADER_OUT: + if (auto TokenPos = DataView.find("Authorization: Bearer "sv); TokenPos != std::string_view::npos) + { + std::string Copy(DataView); + auto BearerStart = TokenPos + 22; + auto BearerEnd = Copy.find_first_of("\r\n", BearerStart); + if (BearerEnd == std::string::npos) + { + BearerEnd = Copy.length(); + } + Copy.replace(Copy.begin() + BearerStart, Copy.begin() + BearerEnd, fmt::format("[{} char token]", BearerEnd - BearerStart)); + ZEN_INFO("HOUT: {}", Copy); + } + else + { + ZEN_INFO("HOUT: {}", DataView); + } + break; + default: + break; + } + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +static std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +static curl_slist* +BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, + std::string_view SessionId, + const std::optional<HttpClientAccessToken>& AccessToken, + const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) +{ + curl_slist* Headers = nullptr; + + for (const auto& [Key, Value] : *AdditionalHeader) + { + ExtendableStringBuilder<64> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + if (!SessionId.empty()) + { + ExtendableStringBuilder<64> SessionHeader; + SessionHeader << "UE-Session: " << SessionId; + Headers = curl_slist_append(Headers, SessionHeader.c_str()); + } + + if (AccessToken) + { + ExtendableStringBuilder<128> AuthHeader; + AuthHeader << "Authorization: " << AccessToken->Value; + Headers = curl_slist_append(Headers, AuthHeader.c_str()); + } + + for (const auto& [Key, Value] : ExtraHeaders) + { + ExtendableStringBuilder<128> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + return Headers; +} + +static HttpClient::KeyValueMap +BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers) +{ + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + return HeaderMap; +} + +// Scans response headers for Content-Type and applies it to the buffer. +static void +ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers) +{ + for (const auto& [Key, Value] : Headers) + { + if (StrCaseCompare(Key, "Content-Type") == 0) + { + Buffer.SetContentType(ParseContentType(Value)); + break; + } + } +} + +static void +AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input) +{ + static constexpr char HexDigits[] = "0123456789ABCDEF"; + static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"); + + for (char C : Input) + { + if (Unreserved.Contains(C)) + { + Out.Append(C); + } + else + { + uint8_t Byte = static_cast<uint8_t>(C); + char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]}; + Out.Append(std::string_view(Encoded, 3)); + } + } +} + +static void +BuildUrlWithParameters(StringBuilderBase& Url, + std::string_view BaseUrl, + std::string_view ResourcePath, + const HttpClient::KeyValueMap& Parameters) +{ + Url.Append(BaseUrl); + Url.Append(ResourcePath); + + if (!Parameters->empty()) + { + char Separator = '?'; + for (const auto& [Key, Value] : *Parameters) + { + Url.Append(Separator); + AppendUrlEncoded(Url, Key); + Url.Append('='); + AppendUrlEncoded(Url, Value); + Separator = '&'; + } + } +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::CurlHttpClient(std::string_view BaseUri, + const HttpClientSettings& ConnectionSettings, + std::function<bool()>&& CheckIfAbortFunction) +: HttpClientBase(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)) +{ +} + +CurlHttpClient::~CurlHttpClient() +{ + ZEN_TRACE_CPU("CurlHttpClient::~CurlHttpClient"); + m_SessionLock.WithExclusiveLock([&] { + for (auto* Handle : m_Sessions) + { + curl_easy_cleanup(Handle); + } + m_Sessions.clear(); + }); +} + +CurlHttpClient::Session::~Session() +{ + if (HeaderList) + { + curl_slist_free_all(HeaderList); + } + Outer->ReleaseSession(Handle); +} + +void +CurlHttpClient::Session::SetHeaders(curl_slist* Headers) +{ + if (HeaderList) + { + curl_slist_free_all(HeaderList); + } + HeaderList = Headers; + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, HeaderList); +} + +CurlHttpClient::CurlResult +CurlHttpClient::Session::PerformWithResponseCallbacks() +{ + std::string Body; + WriteCallbackData WriteData{.Body = &Body, + .CheckIfAbortFunction = Outer->m_CheckIfAbortFunction ? &Outer->m_CheckIfAbortFunction : nullptr}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + return Result; +} + +CurlHttpClient::CurlResult +CurlHttpClient::Session::Perform() +{ + CurlResult Result; + + char ErrorBuffer[CURL_ERROR_SIZE] = {}; + curl_easy_setopt(Handle, CURLOPT_ERRORBUFFER, ErrorBuffer); + + Result.ErrorCode = curl_easy_perform(Handle); + + if (Result.ErrorCode != CURLE_OK) + { + Result.ErrorMessage = ErrorBuffer[0] ? std::string(ErrorBuffer) : curl_easy_strerror(Result.ErrorCode); + } + + curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &Result.StatusCode); + + double Elapsed = 0; + curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed); + Result.ElapsedSeconds = Elapsed; + + curl_off_t UpBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes); + Result.UploadedBytes = static_cast<int64_t>(UpBytes); + + curl_off_t DownBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); + Result.DownloadedBytes = static_cast<int64_t>(DownBytes); + + return Result; +} + +bool +CurlHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const +{ + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes; + return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end(); +} + +HttpClient::Response +CurlHttpClient::ResponseWithPayload(std::string_view SessionId, + CurlResult&& Result, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) +{ + IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, Result.Body.data(), Result.Body.size()); + + ApplyContentTypeFromHeaders(ResponseBuffer, Result.Headers); + + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + { + if (ShouldLogErrorCode(WorkResponseCode)) + { + ZEN_WARN("HttpClient request failed (session: {}): status={}, url={}", + SessionId, + static_cast<int>(WorkResponseCode), + m_BaseUri); + } + } + + std::sort(BoundaryPositions.begin(), + BoundaryPositions.end(), + [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) { + return Lhs.RangeOffset < Rhs.RangeOffset; + }); + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .ResponsePayload = std::move(ResponseBuffer), + .Header = BuildHeaderMap(Result.Headers), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Ranges = std::move(BoundaryPositions)}; +} + +HttpClient::Response +CurlHttpClient::CommonResponse(std::string_view SessionId, + CurlResult&& Result, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) +{ + const HttpResponseCode WorkResponseCode = HttpResponseCode(Result.StatusCode); + if (Result.ErrorCode != CURLE_OK) + { + const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); + if (!Quiet) + { + if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT && + Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK) + { + ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'", + SessionId, + static_cast<int>(Result.ErrorCode), + Result.ErrorMessage); + } + } + + return HttpClient::Response{ + .StatusCode = WorkResponseCode, + .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Result.Body.data(), Result.Body.size()), + .Header = BuildHeaderMap(Result.Headers), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(Result.ErrorCode), .ErrorMessage = Result.ErrorMessage}}; + } + + if (WorkResponseCode == HttpResponseCode::NoContent || (Result.Body.empty() && !Payload)) + { + return HttpClient::Response{.StatusCode = WorkResponseCode, + .Header = BuildHeaderMap(Result.Headers), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds}; + } + else + { + return ResponseWithPayload(SessionId, std::move(Result), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions)); + } +} + +bool +CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + ZEN_TRACE_CPU("ValidatePayload"); + + IoBuffer ResponseBuffer = (Result.Body.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() + : IoBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size()); + + // Collect relevant headers in a single pass + std::string_view ContentLengthValue; + std::string_view IoHashValue; + std::string_view ContentTypeValue; + + for (const auto& [Key, Value] : Result.Headers) + { + if (ContentLengthValue.empty() && StrCaseCompare(Key, "Content-Length") == 0) + { + ContentLengthValue = Value; + } + else if (IoHashValue.empty() && StrCaseCompare(Key, "X-Jupiter-IoHash") == 0) + { + IoHashValue = Value; + } + else if (ContentTypeValue.empty() && StrCaseCompare(Key, "Content-Type") == 0) + { + ContentTypeValue = Value; + } + } + + // Validate Content-Length + if (!ContentLengthValue.empty()) + { + std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLengthValue); + if (!ExpectedContentSize.has_value()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLengthValue); + return false; + } + if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = + fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLengthValue); + return false; + } + } + + if (Result.StatusCode == static_cast<long>(HttpResponseCode::PartialContent)) + { + return true; + } + + // Validate X-Jupiter-IoHash + if (!IoHashValue.empty()) + { + IoHash ExpectedPayloadHash; + if (IoHash::TryParse(IoHashValue, ExpectedPayloadHash)) + { + IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); + if (PayloadHash != ExpectedPayloadHash) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", + PayloadHash.ToHexString(), + ExpectedPayloadHash.ToHexString()); + return false; + } + } + } + + // Validate content-type specific payload + if (ContentTypeValue == "application/x-ue-comp") + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, + RawHash, + RawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = "Compressed binary failed validation"; + return false; + } + } + if (ContentTypeValue == "application/x-ue-cb") + { + if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); + Error == CbValidateError::None) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error)); + return false; + } + } + + return true; +} + +bool +CurlHttpClient::ShouldRetry(const CurlResult& Result) +{ + switch (Result.ErrorCode) + { + case CURLE_OK: + break; + case CURLE_RECV_ERROR: + case CURLE_SEND_ERROR: + case CURLE_OPERATION_TIMEDOUT: + return true; + default: + return false; + } + switch (static_cast<HttpResponseCode>(Result.StatusCode)) + { + case HttpResponseCode::RequestTimeout: + case HttpResponseCode::TooManyRequests: + case HttpResponseCode::InternalServerError: + case HttpResponseCode::BadGateway: + case HttpResponseCode::ServiceUnavailable: + case HttpResponseCode::GatewayTimeout: + return true; + default: + return false; + } +} + +CurlHttpClient::CurlResult +CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::function<bool(CurlResult&)>&& Validate) +{ + uint8_t Attempt = 0; + CurlResult Result = Func(); + while (Attempt < m_ConnectionSettings.RetryCount) + { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return Result; + } + if (!ShouldRetry(Result)) + { + if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) + { + break; + } + if (Validate(Result)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) + { + if (Result.ErrorCode != CURLE_OK) + { + ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}", + SessionId, + static_cast<int>(MapCurlError(Result.ErrorCode)), + Result.ErrorMessage, + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + else + { + ZEN_INFO("Retry (session: {}): HTTP status ({}) '{}' Attempt {}/{}", + SessionId, + Result.StatusCode, + zen::ToString(HttpResponseCode(Result.StatusCode)), + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + } + Result = Func(); + } + return Result; +} + +CurlHttpClient::CurlResult +CurlHttpClient::DoWithRetry(std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + return DoWithRetry(SessionId, std::move(Func), [&](CurlResult& Result) { return ValidatePayload(Result, PayloadFile); }); +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::Session +CurlHttpClient::AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::AllocSession"); + CURL* Handle = nullptr; + m_SessionLock.WithExclusiveLock([&] { + if (!m_Sessions.empty()) + { + Handle = m_Sessions.back(); + m_Sessions.pop_back(); + } + }); + + if (Handle == nullptr) + { + Handle = curl_easy_init(); + if (Handle == nullptr) + { + ThrowOutOfMemory("curl_easy_init"); + } + } + else + { + curl_easy_reset(Handle); + } + + // Unix domain socket + if (!m_ConnectionSettings.UnixSocketPath.empty()) + { + std::string SocketPathUtf8 = PathToUtf8(m_ConnectionSettings.UnixSocketPath); + curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, SocketPathUtf8.c_str()); + } + + // Build URL with parameters + ExtendableStringBuilder<256> Url; + BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters); + curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); + + // Timeouts + if (m_ConnectionSettings.ConnectTimeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(m_ConnectionSettings.ConnectTimeout.count())); + } + if (m_ConnectionSettings.Timeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_ConnectionSettings.Timeout.count())); + } + + // HTTP/2 + if (m_ConnectionSettings.AssumeHttp2) + { + curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); + } + + // Verbose/debug + if (m_ConnectionSettings.Verbose) + { + curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); + curl_easy_setopt(Handle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback); + curl_easy_setopt(Handle, CURLOPT_DEBUGDATA, &m_Log); + } + + // SSL options + if (m_ConnectionSettings.InsecureSsl) + { + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L); + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L); + } + if (!m_ConnectionSettings.CaBundlePath.empty()) + { + curl_easy_setopt(Handle, CURLOPT_CAINFO, m_ConnectionSettings.CaBundlePath.c_str()); + } + + // Disable signal handling for thread safety + curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); + + if (m_ConnectionSettings.ForbidReuseConnection) + { + curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); + } + + // Note: Headers are NOT set here. Each method builds its own header list + // (potentially adding method-specific headers like Content-Type) and passes + // ownership to the Session via SetHeaders(). + + return Session(this, Handle); +} + +void +CurlHttpClient::ReleaseSession(CURL* Handle) +{ + ZEN_TRACE_CPU("CurlHttpClient::ReleaseSession"); + m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(Handle); }); +} + +////////////////////////////////////////////////////////////////////////// + +// TransactPackage is a two-phase protocol (offer + send) with server-side state +// between phases, so retrying individual phases would be incorrect. +CurlHttpClient::Response +CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::TransactPackage"); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++CurlHttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (!Attachments.empty()) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + Writer.AddHash(Attachment.GetHash()); + } + + Writer.EndArray(); + + BinaryWriter MemWriter; + Writer.Save(MemWriter); + + std::vector<std::pair<std::string, std::string>> OfferExtraHeaders; + OfferExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackageOffer)); + OfferExtraHeaders.emplace_back("UE-Request", RequestIdString); + + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders)); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(MemWriter.Data())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(MemWriter.Size())); + + CurlResult Result = Sess.PerformWithResponseCallbacks(); + + if (Result.ErrorCode == CURLE_OK && IsHttpSuccessCode(Result.StatusCode)) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size()); + CbValidateError ValidationError = CbValidateError::None; + if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); + ValidationError == CbValidateError::None) + { + for (CbFieldView& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + std::vector<std::pair<std::string, std::string>> PkgExtraHeaders; + PkgExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackage)); + PkgExtraHeaders.emplace_back("UE-Request", RequestIdString); + + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders)); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(FlatMessage.GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(FlatMessage.GetSize())); + + CurlResult Result = Sess.PerformWithResponseCallbacks(); + + return CommonResponse(m_SessionId, std::move(Result), {}, {}); +} + +////////////////////////////////////////////////////////////////////////// +// +// Standard HTTP verbs +// + +CurlHttpClient::Response +CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())})); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; + Session Sess = AllocSession(Url, Parameters); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken())); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::Get"); + return CommonResponse(m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, Parameters); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_HTTPGET, 1L); + return Sess.PerformWithResponseCallbacks(); + }, + [this](CurlResult& Result) { + std::unique_ptr<detail::TempPayloadFile> NoTempFile; + return ValidatePayload(Result, NoTempFile); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Head"); + + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_NOBODY, 1L); + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Delete"); + + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_CUSTOMREQUEST, "DELETE"); + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostNoPayload"); + + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, Parameters); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_POST, 1L); + curl_easy_setopt(Sess.Get(), CURLOPT_POSTFIELDSIZE, 0L); + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostWithPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + + FileReadCallbackData ReadData{.Buffer = &Buffer, + .TotalSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + } + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader, + const std::filesystem::path& TempFolderPath) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostObjectPayload"); + + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + + CurlResult Result = DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + PayloadString.clear(); + PayloadFile.reset(); + + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)})); + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetBuffer().GetSize())); + + struct PostHeaderCallbackData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + uint64_t MaxInMemorySize = 0; + LoggerRef Log; + }; + + PostHeaderCallbackData PostHdrData; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + PostHdrData.Headers = &ResponseHeaders; + PostHdrData.PayloadFile = &PayloadFile; + PostHdrData.PayloadString = &PayloadString; + PostHdrData.TempFolderPath = &TempFolderPath; + PostHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize; + PostHdrData.Log = m_Log; + + auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<PostHeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) + { + auto& [Key, Value] = *Header; + + if (StrCaseCompare(Key, "Content-Length") == 0) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Value); + if (ContentLength.has_value()) + { + if (!Data->TempFolderPath->empty() && ContentLength.value() > Data->MaxInMemorySize) + { + *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + Data->PayloadFile->reset(); + } + } + else + { + Data->PayloadString->reserve(ContentLength.value()); + } + } + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb)); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &PostHdrData); + + struct PostWriteCallbackData + { + std::string* PayloadString = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + LoggerRef Log; + }; + + PostWriteCallbackData PostWriteData; + PostWriteData.PayloadString = &PayloadString; + PostWriteData.PayloadFile = &PayloadFile; + PostWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr; + PostWriteData.TempFolderPath = &TempFolderPath; + PostWriteData.Log = m_Log; + + auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<PostWriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; + } + + if (*Data->PayloadFile) + { + std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + return 0; + } + } + else + { + Data->PayloadString->append(Ptr, TotalBytes); + } + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &PostWriteData); + + CurlResult Res = Sess.Perform(); + Res.Headers = std::move(ResponseHeaders); + + if (!PayloadString.empty()) + { + Res.Body = std::move(PayloadString); + } + + return Res; + }, + PayloadFile); + + return CommonResponse(m_SessionId, std::move(Result), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Post"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + + StreamReadCallbackData ReadData{.Reader = &Reader, + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())})); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + + FileReadCallbackData ReadData{.Buffer = &Buffer, + .TotalSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + } + + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + + StreamReadCallbackData ReadData{.Reader = &Reader, + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Download"); + + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + + HttpContentType ContentType = HttpContentType::kUnknownContentType; + detail::MultipartBoundaryParser BoundaryParser; + bool IsMultiRangeResponse = false; + + CurlResult Result = DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); + + // Reset state from any previous attempt + PayloadString.clear(); + PayloadFile.reset(); + BoundaryParser.Boundaries.clear(); + ContentType = HttpContentType::kUnknownContentType; + IsMultiRangeResponse = false; + + // Track requested content length from Range header (sum all ranges) + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + std::string_view RangeValue(RangeIt->second); + size_t RangeStartPos = RangeValue.find('=', 5); + if (RangeStartPos != std::string_view::npos) + { + RangeStartPos++; + while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') + { + RangeStartPos++; + } + RequestedContentLength = 0; + + while (RangeStartPos < RangeValue.length()) + { + size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); + if (RangeEnd == std::string_view::npos) + { + RangeEnd = RangeValue.length(); + } + + std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); + size_t RangeSplitPos = RangeString.find('-'); + if (RangeSplitPos != std::string_view::npos) + { + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1; + } + } + RangeStartPos = RangeEnd; + while (RangeStartPos != RangeValue.length() && + (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' ')) + { + RangeStartPos++; + } + } + } + } + } + + // Header callback that detects Content-Length and switches to file-backed storage when needed + struct DownloadHeaderCallbackData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + uint64_t MaxInMemorySize = 0; + LoggerRef Log; + detail::MultipartBoundaryParser* BoundaryParser = nullptr; + bool* IsMultiRange = nullptr; + HttpContentType* ContentTypeOut = nullptr; + }; + + DownloadHeaderCallbackData DlHdrData; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + DlHdrData.Headers = &ResponseHeaders; + DlHdrData.PayloadFile = &PayloadFile; + DlHdrData.PayloadString = &PayloadString; + DlHdrData.TempFolderPath = &TempFolderPath; + DlHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize; + DlHdrData.Log = m_Log; + DlHdrData.BoundaryParser = &BoundaryParser; + DlHdrData.IsMultiRange = &IsMultiRangeResponse; + DlHdrData.ContentTypeOut = &ContentType; + + auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<DownloadHeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) + { + auto& [KeyView, Value] = *Header; + const std::string Key(KeyView); + + if (StrCaseCompare(Key, "Content-Length") == 0) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Value); + if (ContentLength.has_value()) + { + if (ContentLength.value() > Data->MaxInMemorySize) + { + *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + Data->PayloadFile->reset(); + } + } + else + { + Data->PayloadString->reserve(ContentLength.value()); + } + } + } + else if (StrCaseCompare(Key, "Content-Type") == 0) + { + *Data->IsMultiRange = Data->BoundaryParser->Init(Value); + if (!*Data->IsMultiRange) + { + *Data->ContentTypeOut = ParseContentType(Value); + } + } + else if (StrCaseCompare(Key, "Content-Range") == 0) + { + if (!*Data->IsMultiRange) + { + std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Value); + if (Range.second != 0) + { + Data->BoundaryParser->Boundaries.push_back( + HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0, + .RangeOffset = Range.first, + .RangeLength = Range.second, + .ContentType = *Data->ContentTypeOut}); + } + } + } + + Data->Headers->emplace_back(Key, std::string(Value)); + } + + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb)); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &DlHdrData); + + // Write callback that directs data to file or string + struct DownloadWriteCallbackData + { + std::string* PayloadString = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + LoggerRef Log; + detail::MultipartBoundaryParser* BoundaryParser = nullptr; + bool* IsMultiRange = nullptr; + }; + + DownloadWriteCallbackData DlWriteData; + DlWriteData.PayloadString = &PayloadString; + DlWriteData.PayloadFile = &PayloadFile; + DlWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr; + DlWriteData.TempFolderPath = &TempFolderPath; + DlWriteData.Log = m_Log; + DlWriteData.BoundaryParser = &BoundaryParser; + DlWriteData.IsMultiRange = &IsMultiRangeResponse; + + auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<DownloadWriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; + } + + if (*Data->IsMultiRange) + { + Data->BoundaryParser->ParseInput(std::string_view(Ptr, TotalBytes)); + } + + if (*Data->PayloadFile) + { + std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + return 0; + } + } + else + { + Data->PayloadString->append(Ptr, TotalBytes); + } + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &DlWriteData); + + CurlResult Res = Sess.Perform(); + Res.Headers = std::move(ResponseHeaders); + + // Handle resume logic + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const CurlResult& R) -> bool { + for (const auto& [K, V] : R.Headers) + { + if (StrCaseCompare(K, "Content-Range") == 0) + { + return true; + } + if (StrCaseCompare(K, "Accept-Ranges") == 0) + { + return V == "bytes"sv; + } + } + return false; + }; + + auto ShouldResumeCheck = [&SupportsRanges, &IsMultiRangeResponse](const CurlResult& R) -> bool { + if (IsMultiRangeResponse) + { + return false; + } + if (ShouldRetry(R)) + { + return SupportsRanges(R); + } + return false; + }; + + if (ShouldResumeCheck(Res)) + { + // Find Content-Length + std::string ContentLengthValue; + for (const auto& [K, V] : Res.Headers) + { + if (StrCaseCompare(K, "Content-Length") == 0) + { + ContentLengthValue = V; + break; + } + } + + if (!ContentLengthValue.empty()) + { + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(ContentLengthValue); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + KeyValueMap HeadersWithRange(AdditionalHeader); + uint8_t ResumeAttempt = 0; + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + break; // No progress, abort + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session ResumeSess = AllocSession(Url, {}); + CURL* ResumeH = ResumeSess.Get(); + + ResumeSess.SetHeaders(BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken())); + curl_easy_setopt(ResumeH, CURLOPT_HTTPGET, 1L); + + std::vector<std::pair<std::string, std::string>> ResumeHeaders; + + struct ResumeHeaderCbData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + }; + + ResumeHeaderCbData ResumeHdrData; + ResumeHdrData.Headers = &ResumeHeaders; + ResumeHdrData.PayloadFile = &PayloadFile; + ResumeHdrData.PayloadString = &PayloadString; + + auto ResumeHeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<ResumeHeaderCbData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)); + if (!Header) + { + return TotalBytes; + } + auto& [Key, Value] = *Header; + + if (StrCaseCompare(Key, "Content-Range") == 0) + { + if (Value.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Value.find('-', 6); + if (RangeStartEnd != std::string_view::npos) + { + const std::optional<uint64_t> Start = ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = + *Data->PayloadFile ? (*Data->PayloadFile)->GetSize() : Data->PayloadString->length(); + if (Start.value() == DownloadedSize) + { + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (*Data->PayloadFile) + { + (*Data->PayloadFile)->ResetWritePos(Start.value()); + } + else + { + *Data->PayloadString = Data->PayloadString->substr(0, Start.value()); + } + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + } + } + } + return 0; + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + }; + + curl_easy_setopt(ResumeH, + CURLOPT_HEADERFUNCTION, + static_cast<size_t (*)(char*, size_t, size_t, void*)>(ResumeHeaderCb)); + curl_easy_setopt(ResumeH, CURLOPT_HEADERDATA, &ResumeHdrData); + curl_easy_setopt(ResumeH, + CURLOPT_WRITEFUNCTION, + static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(ResumeH, CURLOPT_WRITEDATA, &DlWriteData); + + Res = ResumeSess.Perform(); + Res.Headers = std::move(ResumeHeaders); + + ResumeAttempt++; + } while (ResumeAttempt < m_ConnectionSettings.RetryCount && ShouldResumeCheck(Res)); + } + } + } + + if (!PayloadString.empty()) + { + Res.Body = std::move(PayloadString); + } + + return Res; + }, + PayloadFile); + + return CommonResponse(m_SessionId, + std::move(Result), + PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, + std::move(BoundaryParser.Boundaries)); +} + +} // namespace zen diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h new file mode 100644 index 000000000..b7fa52e6c --- /dev/null +++ b/src/zenhttp/clients/httpclientcurl.h @@ -0,0 +1,137 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "httpclientcommon.h" + +#include <zencore/logging.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <curl/curl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CurlHttpClient : public HttpClientBase +{ +public: + CurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction); + ~CurlHttpClient(); + + // HttpClientBase + + [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader = {}, + const std::filesystem::path& TempFolderPath = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response Download(std::string_view Url, + const std::filesystem::path& TempFolderPath, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response TransactPackage(std::string_view Url, + CbPackage Package, + const KeyValueMap& AdditionalHeader = {}) override; + +private: + struct CurlResult + { + long StatusCode = 0; + std::string Body; + std::vector<std::pair<std::string, std::string>> Headers; + double ElapsedSeconds = 0; + int64_t UploadedBytes = 0; + int64_t DownloadedBytes = 0; + CURLcode ErrorCode = CURLE_OK; + std::string ErrorMessage; + }; + + struct Session + { + Session(CurlHttpClient* InOuter, CURL* InHandle) : Outer(InOuter), Handle(InHandle) {} + ~Session(); + + CURL* Get() const { return Handle; } + + // Takes ownership of the curl_slist and sets it on the handle. + // The list is freed automatically when the Session is destroyed. + void SetHeaders(curl_slist* Headers); + + // Low-level perform: executes the request and collects status/timing. + CurlResult Perform(); + + // Sets up standard write+header callbacks, performs the request, and + // moves the collected body and headers into the returned CurlResult. + CurlResult PerformWithResponseCallbacks(); + + LoggerRef Log() { return Outer->Log(); } + + private: + CurlHttpClient* Outer; + CURL* Handle; + curl_slist* HeaderList = nullptr; + + Session(Session&&) = delete; + Session& operator=(Session&&) = delete; + }; + + Session AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters); + + RwLock m_SessionLock; + std::vector<CURL*> m_Sessions; + + void ReleaseSession(CURL* Handle); + + CurlResult DoWithRetry(std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile); + CurlResult DoWithRetry( + std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::function<bool(CurlResult&)>&& Validate = [](CurlResult&) { return true; }); + + bool ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile); + + static bool ShouldRetry(const CurlResult& Result); + + bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const; + + HttpClient::Response CommonResponse(std::string_view SessionId, + CurlResult&& Result, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {}); + + HttpClient::Response ResponseWithPayload(std::string_view SessionId, + CurlResult&& Result, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions); +}; + +} // namespace zen diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp new file mode 100644 index 000000000..fbae9f5fe --- /dev/null +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -0,0 +1,641 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpwsclient.h> + +#include "../servers/wsframecodec.h" + +#include <zencore/base64.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/string.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +#include <deque> +#include <random> +#include <thread> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +struct HttpWsClient::Impl +{ + Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_OwnedIoContext(std::make_unique<asio::io_context>()) + , m_IoContext(*m_OwnedIoContext) + { + ParseUrl(Url); + } + + Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_IoContext(IoContext) + { + ParseUrl(Url); + } + + ~Impl() + { + // Release work guard so io_context::run() can return + m_WorkGuard.reset(); + + // Close the socket to cancel pending async ops + CloseSocket(); + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + } + + void CloseSocket() + { + asio::error_code Ec; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_UnixSocket) + { + m_UnixSocket->close(Ec); + return; + } +#endif + if (m_TcpSocket) + { + m_TcpSocket->close(Ec); + } + } + + template<typename Fn> + void WithSocket(Fn&& Func) + { +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_UnixSocket) + { + Func(*m_UnixSocket); + return; + } +#endif + Func(*m_TcpSocket); + } + + void ParseUrl(std::string_view Url) + { + // Expected format: ws://host:port/path + if (Url.substr(0, 5) == "ws://") + { + Url.remove_prefix(5); + } + + auto SlashPos = Url.find('/'); + std::string_view HostPort; + if (SlashPos != std::string_view::npos) + { + HostPort = Url.substr(0, SlashPos); + m_Path = std::string(Url.substr(SlashPos)); + } + else + { + HostPort = Url; + m_Path = "/"; + } + + auto ColonPos = HostPort.find(':'); + if (ColonPos != std::string_view::npos) + { + m_Host = std::string(HostPort.substr(0, ColonPos)); + m_Port = std::string(HostPort.substr(ColonPos + 1)); + } + else + { + m_Host = std::string(HostPort); + m_Port = "80"; + } + } + + void Connect() + { + if (m_OwnedIoContext) + { + m_WorkGuard.emplace(m_IoContext.get_executor()); + m_IoThread = std::thread([this] { m_IoContext.run(); }); + } + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (!m_Settings.UnixSocketPath.empty()) + { + asio::post(m_IoContext, [this] { DoConnectUnix(); }); + return; + } +#endif + + asio::post(m_IoContext, [this] { DoResolve(); }); + } + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + void DoConnectUnix() + { + m_UnixSocket = std::make_unique<asio::local::stream_protocol::socket>(m_IoContext); + + // Start connect timeout timer + m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout); + m_Timer->async_wait([this](const asio::error_code& Ec) { + if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect timeout for {}", m_Settings.UnixSocketPath); + CloseSocket(); + } + }); + + asio::local::stream_protocol::endpoint Endpoint(PathToUtf8(m_Settings.UnixSocketPath)); + m_UnixSocket->async_connect(Endpoint, [this](const asio::error_code& Ec) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect failed for {}: {}", m_Settings.UnixSocketPath, Ec.message()); + m_Handler.OnWsClose(1006, "connect failed"); + return; + } + + DoHandshake(); + }); + } +#endif + + void DoResolve() + { + m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext); + + m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) { + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "resolve failed"); + return; + } + + DoConnect(Results); + }); + } + + void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints) + { + m_TcpSocket = std::make_unique<asio::ip::tcp::socket>(m_IoContext); + + // Start connect timeout timer + m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout); + m_Timer->async_wait([this](const asio::error_code& Ec) { + if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port); + CloseSocket(); + } + }); + + asio::async_connect(*m_TcpSocket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "connect failed"); + return; + } + + DoHandshake(); + }); + } + + void DoHandshake() + { + // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded) + uint8_t KeyBytes[16]; + { + static thread_local std::mt19937 s_Rng(std::random_device{}()); + for (int i = 0; i < 4; ++i) + { + uint32_t Val = s_Rng(); + std::memcpy(KeyBytes + i * 4, &Val, 4); + } + } + + char KeyBase64[Base64::GetEncodedDataSize(16) + 1]; + uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64); + KeyBase64[KeyLen] = '\0'; + m_WebSocketKey = std::string(KeyBase64, KeyLen); + + // Build the HTTP upgrade request + ExtendableStringBuilder<512> Request; + Request << "GET " << m_Path << " HTTP/1.1\r\n" + << "Host: " << m_Host << ":" << m_Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n" + << "Sec-WebSocket-Version: 13\r\n"; + + // Add Authorization header if access token provider is set + if (m_Settings.AccessTokenProvider) + { + HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); + if (Token.IsValid()) + { + Request << "Authorization: Bearer " << Token.Value << "\r\n"; + } + } + + Request << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + + m_HandshakeBuffer = std::make_shared<std::string>(ReqStr); + + WithSocket([this](auto& Socket) { + asio::async_write(Socket, + asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()), + [this](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake write failed"); + return; + } + + DoReadHandshakeResponse(); + }); + }); + } + + void DoReadHandshakeResponse() + { + WithSocket([this](auto& Socket) { + asio::async_read_until(Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) { + m_Timer->cancel(); + + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake read failed"); + return; + } + + // Parse the response + const auto& Data = m_ReadBuffer.data(); + std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); + + // Consume the headers from the read buffer (any extra data stays for frame parsing) + auto HeaderEnd = Response.find("\r\n\r\n"); + if (HeaderEnd != std::string::npos) + { + m_ReadBuffer.consume(HeaderEnd + 4); + } + + // Validate 101 response + if (Response.find("101") == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); + m_Handler.OnWsClose(1006, "handshake rejected"); + return; + } + + // Validate Sec-WebSocket-Accept + std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); + if (Response.find(ExpectedAccept) == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); + m_Handler.OnWsClose(1006, "invalid accept key"); + return; + } + + m_IsOpen.store(true); + m_Handler.OnWsOpen(); + EnqueueRead(); + }); + }); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Read loop + // + + void EnqueueRead() + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + WithSocket([this](auto& Socket) { + asio::async_read(Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) { + OnDataReceived(Ec); + }); + }); + } + + void OnDataReceived(const asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message()); + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "connection lost"); + } + return; + } + + ProcessReceivedData(); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + EnqueueRead(); + } + } + + void ProcessReceivedData() + { + while (m_ReadBuffer.size() > 0) + { + const auto& InputBuffer = m_ReadBuffer.data(); + const auto* RawData = static_cast<const uint8_t*>(InputBuffer.data()); + const auto Size = InputBuffer.size(); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size); + if (!Frame.IsValid) + { + break; + } + + m_ReadBuffer.consume(Frame.BytesConsumed); + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWsMessage(Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with masked pong + std::vector<uint8_t> PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = + std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo masked close frame if we haven't sent one yet + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + + m_IsOpen.store(false); + m_Handler.OnWsClose(Code, Reason); + return; + } + + default: + ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode)); + break; + } + } + } + + ////////////////////////////////////////////////////////////////////////// + // + // Write queue + // + + void EnqueueWrite(std::vector<uint8_t> Frame) + { + bool ShouldFlush = false; + + m_WriteLock.WithExclusiveLock([&] { + m_WriteQueue.push_back(std::move(Frame)); + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + }); + + if (ShouldFlush) + { + FlushWriteQueue(); + } + } + + void FlushWriteQueue() + { + std::vector<uint8_t> Frame; + + m_WriteLock.WithExclusiveLock([&] { + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + Frame = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + }); + + if (Frame.empty()) + { + return; + } + + auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame)); + + WithSocket([this, OwnedFrame](auto& Socket) { + asio::async_write(Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); }); + }); + } + + void OnWriteComplete(const asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message()); + } + + m_WriteLock.WithExclusiveLock([&] { + m_IsWriting = false; + m_WriteQueue.clear(); + }); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "write error"); + } + return; + } + + FlushWriteQueue(); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Public operations + // + + void SendText(std::string_view Text) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); + } + + void SendBinary(std::span<const uint8_t> Data) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); + } + + void DoClose(uint16_t Code, std::string_view Reason) + { + if (!m_IsOpen.exchange(false)) + { + return; + } + + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + } + + IWsClientHandler& m_Handler; + HttpWsClientSettings m_Settings; + LoggerRef m_Log; + + std::string m_Host; + std::string m_Port; + std::string m_Path; + + // io_context: owned (standalone) or external (shared) + std::unique_ptr<asio::io_context> m_OwnedIoContext; + asio::io_context& m_IoContext; + std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard; + std::thread m_IoThread; + + // Connection state + std::unique_ptr<asio::ip::tcp::resolver> m_Resolver; + std::unique_ptr<asio::ip::tcp::socket> m_TcpSocket; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + std::unique_ptr<asio::local::stream_protocol::socket> m_UnixSocket; +#endif + std::unique_ptr<asio::steady_timer> m_Timer; + asio::streambuf m_ReadBuffer; + std::string m_WebSocketKey; + std::shared_ptr<std::string> m_HandshakeBuffer; + + // Write queue + RwLock m_WriteLock; + std::deque<std::vector<uint8_t>> m_WriteQueue; + bool m_IsWriting = false; + + std::atomic<bool> m_IsOpen{false}; + std::atomic<bool> m_CloseSent{false}; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(Url, Handler, Settings)) +{ +} + +HttpWsClient::HttpWsClient(std::string_view Url, + IWsClientHandler& Handler, + asio::io_context& IoContext, + const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(Url, Handler, IoContext, Settings)) +{ +} + +HttpWsClient::~HttpWsClient() = default; + +void +HttpWsClient::Connect() +{ + m_Impl->Connect(); +} + +void +HttpWsClient::SendText(std::string_view Text) +{ + m_Impl->SendText(Text); +} + +void +HttpWsClient::SendBinary(std::span<const uint8_t> Data) +{ + m_Impl->SendBinary(Data); +} + +void +HttpWsClient::Close(uint16_t Code, std::string_view Reason) +{ + m_Impl->DoClose(Code, Reason); +} + +bool +HttpWsClient::IsOpen() const +{ + return m_Impl->m_IsOpen.load(std::memory_order_relaxed); +} + +} // namespace zen |