aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/clients
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp/clients')
-rw-r--r--src/zenhttp/clients/httpclientcommon.cpp380
-rw-r--r--src/zenhttp/clients/httpclientcommon.h120
-rw-r--r--src/zenhttp/clients/httpclientcpr.cpp791
-rw-r--r--src/zenhttp/clients/httpclientcpr.h34
-rw-r--r--src/zenhttp/clients/httpclientcurl.cpp1816
-rw-r--r--src/zenhttp/clients/httpclientcurl.h137
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp641
7 files changed, 3590 insertions, 329 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp
index 47425e014..e4d11547a 100644
--- a/src/zenhttp/clients/httpclientcommon.cpp
+++ b/src/zenhttp/clients/httpclientcommon.cpp
@@ -142,7 +142,10 @@ namespace detail {
DataSize -= CopySize;
if (m_CacheBufferOffset == CacheBufferSize)
{
- AppendData(m_CacheBuffer, CacheBufferSize);
+ if (std::error_code Ec = AppendData(m_CacheBuffer, CacheBufferSize))
+ {
+ return Ec;
+ }
if (DataSize > 0)
{
ZEN_ASSERT(DataSize < CacheBufferSize);
@@ -382,6 +385,177 @@ namespace detail {
return Result;
}
+ MultipartBoundaryParser::MultipartBoundaryParser() : BoundaryEndMatcher("--"), HeaderEndMatcher("\r\n\r\n") {}
+
+ bool MultipartBoundaryParser::Init(const std::string_view ContentTypeHeaderValue)
+ {
+ std::string LowerCaseValue = ToLower(ContentTypeHeaderValue);
+ if (LowerCaseValue.starts_with("multipart/byteranges"))
+ {
+ size_t BoundaryPos = LowerCaseValue.find("boundary=");
+ if (BoundaryPos != std::string::npos)
+ {
+ // Yes, we do a substring of the non-lowercase value string as we want the exact boundary string
+ std::string_view BoundaryName = std::string_view(ContentTypeHeaderValue).substr(BoundaryPos + 9);
+ size_t BoundaryEnd = std::string::npos;
+ while (!BoundaryName.empty() && BoundaryName[0] == ' ')
+ {
+ BoundaryName = BoundaryName.substr(1);
+ }
+ if (!BoundaryName.empty())
+ {
+ if (BoundaryName.size() > 2 && BoundaryName.front() == '"' && BoundaryName.back() == '"')
+ {
+ BoundaryEnd = BoundaryName.find('"', 1);
+ if (BoundaryEnd != std::string::npos)
+ {
+ BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(1, BoundaryEnd - 1)));
+ return true;
+ }
+ }
+ else
+ {
+ BoundaryEnd = BoundaryName.find_first_of(" \r\n");
+ BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(0, BoundaryEnd)));
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ void MultipartBoundaryParser::ParseInput(std::string_view data)
+ {
+ const char* InputPtr = data.data();
+ size_t InputLength = data.length();
+ size_t ScanPos = 0;
+ while (ScanPos < InputLength)
+ {
+ const char ScanChar = InputPtr[ScanPos];
+ if (BoundaryBeginMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ if (PayloadOffset + ScanPos < (BoundaryBeginMatcher.GetMatchEndOffset() + BoundaryEndMatcher.GetMatchString().length()))
+ {
+ BoundaryEndMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+ if (BoundaryEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ BoundaryBeginMatcher.Reset();
+ HeaderEndMatcher.Reset();
+ BoundaryEndMatcher.Reset();
+ BoundaryHeader.Reset();
+ break;
+ }
+ }
+
+ BoundaryHeader.Append(ScanChar);
+
+ HeaderEndMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+
+ if (HeaderEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete)
+ {
+ const uint64_t HeaderStartOffset = BoundaryBeginMatcher.GetMatchEndOffset();
+ const uint64_t HeaderEndOffset = HeaderEndMatcher.GetMatchStartOffset();
+ const uint64_t HeaderLength = HeaderEndOffset - HeaderStartOffset;
+ std::string_view HeaderText(BoundaryHeader.ToView().substr(0, HeaderLength));
+
+ uint64_t OffsetInPayload = PayloadOffset + ScanPos + 1;
+
+ uint64_t RangeOffset = 0;
+ uint64_t RangeLength = 0;
+ HttpContentType ContentType = HttpContentType::kBinary;
+
+ ForEachStrTok(HeaderText, "\r\n", [&](std::string_view Line) {
+ const std::pair<std::string_view, std::string_view> KeyAndValue = GetHeaderKeyAndValue(Line);
+ const std::string_view Key = KeyAndValue.first;
+ const std::string_view Value = KeyAndValue.second;
+ if (Key == "Content-Range")
+ {
+ std::pair<uint64_t, uint64_t> ContentRange = ParseContentRange(Value);
+ if (ContentRange.second != 0)
+ {
+ RangeOffset = ContentRange.first;
+ RangeLength = ContentRange.second;
+ }
+ }
+ else if (Key == "Content-Type")
+ {
+ ContentType = ParseContentType(Value);
+ }
+
+ return true;
+ });
+
+ if (RangeLength > 0)
+ {
+ Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = OffsetInPayload,
+ .RangeOffset = RangeOffset,
+ .RangeLength = RangeLength,
+ .ContentType = ContentType});
+ }
+
+ BoundaryBeginMatcher.Reset();
+ HeaderEndMatcher.Reset();
+ BoundaryEndMatcher.Reset();
+ BoundaryHeader.Reset();
+ }
+ }
+ else
+ {
+ BoundaryBeginMatcher.Match(PayloadOffset + ScanPos, ScanChar);
+ }
+ ScanPos++;
+ }
+ PayloadOffset += InputLength;
+ }
+
+ std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString)
+ {
+ size_t DelimiterPos = HeaderString.find(':');
+ if (DelimiterPos != std::string::npos)
+ {
+ std::string_view Key = HeaderString.substr(0, DelimiterPos);
+ constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n");
+ Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters);
+ Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters);
+
+ std::string_view Value = HeaderString.substr(DelimiterPos + 1);
+ Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters);
+ Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters);
+ return std::make_pair(Key, Value);
+ }
+ return std::make_pair(HeaderString, std::string_view{});
+ }
+
+ std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value)
+ {
+ if (Value.starts_with("bytes "))
+ {
+ size_t RangeSplitPos = Value.find('-', 6);
+ if (RangeSplitPos != std::string::npos)
+ {
+ size_t RangeEndLength = Value.find('/', RangeSplitPos + 1);
+ if (RangeEndLength == std::string::npos)
+ {
+ RangeEndLength = Value.length() - (RangeSplitPos + 1);
+ }
+ else
+ {
+ RangeEndLength = RangeEndLength - (RangeSplitPos + 1);
+ }
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(Value.substr(6, RangeSplitPos - 6));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(Value.substr(RangeSplitPos + 1, RangeEndLength));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ uint64_t RangeOffset = RequestedRangeStart.value();
+ uint64_t RangeLength = RequestedRangeEnd.value() - RangeOffset + 1;
+ return std::make_pair(RangeOffset, RangeLength);
+ }
+ }
+ }
+ return {0, 0};
+ }
+
} // namespace detail
} // namespace zen
@@ -423,6 +597,8 @@ namespace testutil {
} // namespace testutil
+TEST_SUITE_BEGIN("http.httpclientcommon");
+
TEST_CASE("BufferedReadFileStream")
{
ScopedTemporaryDirectory TmpDir;
@@ -470,5 +646,207 @@ TEST_CASE("CompositeBufferReadStream")
CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data));
}
+TEST_CASE("ParseContentRange")
+{
+ SUBCASE("normal range with total size")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 0-99/500");
+ CHECK_EQ(Offset, 0);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("non-zero offset")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 2638-5111437/44369878");
+ CHECK_EQ(Offset, 2638);
+ CHECK_EQ(Length, 5111437 - 2638 + 1);
+ }
+
+ SUBCASE("wildcard total size")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 100-199/*");
+ CHECK_EQ(Offset, 100);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("no slash (total size omitted)")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 50-149");
+ CHECK_EQ(Offset, 50);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("malformed input returns zeros")
+ {
+ auto [Offset1, Length1] = detail::ParseContentRange("not-bytes 0-99/500");
+ CHECK_EQ(Offset1, 0);
+ CHECK_EQ(Length1, 0);
+
+ auto [Offset2, Length2] = detail::ParseContentRange("bytes abc-def/500");
+ CHECK_EQ(Offset2, 0);
+ CHECK_EQ(Length2, 0);
+
+ auto [Offset3, Length3] = detail::ParseContentRange("");
+ CHECK_EQ(Offset3, 0);
+ CHECK_EQ(Length3, 0);
+
+ auto [Offset4, Length4] = detail::ParseContentRange("bytes 100/500");
+ CHECK_EQ(Offset4, 0);
+ CHECK_EQ(Length4, 0);
+ }
+
+ SUBCASE("single byte range")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 42-42/1000");
+ CHECK_EQ(Offset, 42);
+ CHECK_EQ(Length, 1);
+ }
+}
+
+TEST_CASE("MultipartBoundaryParser")
+{
+ uint64_t Range1Offset = 2638;
+ uint64_t Range1Length = (5111437 - Range1Offset) + 1;
+
+ uint64_t Range2Offset = 5118199;
+ uint64_t Range2Length = (9147741 - Range2Offset) + 1;
+
+ std::string_view ContentTypeHeaderValue1 = "multipart/byteranges; boundary=00000000000000019229";
+ std::string_view ContentTypeHeaderValue2 = "multipart/byteranges; boundary=\"00000000000000019229\"";
+
+ {
+ std::string_view Example1 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/44369878\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample1;
+ ParserExample1.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 7;
+ for (size_t Offset = 0; Offset < Example1.length(); Offset += InputWindow)
+ {
+ ParserExample1.ParseInput(Example1.substr(Offset, Min(Example1.length() - Offset, InputWindow)));
+ }
+
+ CHECK(ParserExample1.Boundaries.size() == 2);
+
+ CHECK(ParserExample1.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample1.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample1.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample1.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example2 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample2;
+ ParserExample2.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 3;
+ for (size_t Offset = 0; Offset < Example2.length(); Offset += InputWindow)
+ {
+ std::string_view Window = Example2.substr(Offset, Min(Example2.length() - Offset, InputWindow));
+ ParserExample2.ParseInput(Window);
+ }
+
+ CHECK(ParserExample2.Boundaries.size() == 2);
+
+ CHECK(ParserExample2.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample2.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample2.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample2.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example3 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita";
+
+ detail::MultipartBoundaryParser ParserExample3;
+ ParserExample3.Init(ContentTypeHeaderValue2);
+
+ const size_t InputWindow = 31;
+ for (size_t Offset = 0; Offset < Example3.length(); Offset += InputWindow)
+ {
+ ParserExample3.ParseInput(Example3.substr(Offset, Min(Example3.length() - Offset, InputWindow)));
+ }
+
+ CHECK(ParserExample3.Boundaries.size() == 2);
+
+ CHECK(ParserExample3.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample3.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample3.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample3.Boundaries[1].RangeLength == Range2Length);
+ }
+
+ {
+ std::string_view Example4 =
+ "\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 2638-5111437/*\r\n"
+ "Not: really\r\n"
+ "\r\n"
+ "datadatadatadata"
+ "\r\n--000000000bait0019229\r\n"
+ "\r\n--00\r\n--000000000bait001922\r\n"
+ "\r\n\r\n\r\r\n--00000000000000019229\r\n"
+ "Content-Type: application/x-ue-comp\r\n"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n"
+ "ditaditadita"
+ "Content-Type: application/x-ue-comp\r\n"
+ "ditaditadita"
+ "Content-Range: bytes 5118199-9147741/44369878\r\n"
+ "\r\n---\r\n--00000000000000019229--";
+
+ detail::MultipartBoundaryParser ParserExample4;
+ ParserExample4.Init(ContentTypeHeaderValue1);
+
+ const size_t InputWindow = 3;
+ for (size_t Offset = 0; Offset < Example4.length(); Offset += InputWindow)
+ {
+ std::string_view Window = Example4.substr(Offset, Min(Example4.length() - Offset, InputWindow));
+ ParserExample4.ParseInput(Window);
+ }
+
+ CHECK(ParserExample4.Boundaries.size() == 2);
+
+ CHECK(ParserExample4.Boundaries[0].RangeOffset == Range1Offset);
+ CHECK(ParserExample4.Boundaries[0].RangeLength == Range1Length);
+ CHECK(ParserExample4.Boundaries[1].RangeOffset == Range2Offset);
+ CHECK(ParserExample4.Boundaries[1].RangeLength == Range2Length);
+ }
+}
+
+TEST_SUITE_END();
+
} // namespace zen
#endif
diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h
index 1d0b7f9ea..e95e3a253 100644
--- a/src/zenhttp/clients/httpclientcommon.h
+++ b/src/zenhttp/clients/httpclientcommon.h
@@ -3,6 +3,7 @@
#pragma once
#include <zencore/compositebuffer.h>
+#include <zencore/string.h>
#include <zencore/trace.h>
#include <zenhttp/httpclient.h>
@@ -35,7 +36,10 @@ public:
const IoBuffer& Payload,
ZenContentType ContentType,
const KeyValueMap& AdditionalHeader = {}) = 0;
- [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) = 0;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ CbObject Payload,
+ const KeyValueMap& AdditionalHeader = {},
+ const std::filesystem::path& TempFolderPath = {}) = 0;
[[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) = 0;
[[nodiscard]] virtual Response Post(std::string_view Url,
const CompositeBuffer& Payload,
@@ -87,7 +91,7 @@ namespace detail {
std::error_code Write(std::string_view DataString);
IoBuffer DetachToIoBuffer();
IoBuffer BorrowIoBuffer();
- inline uint64_t GetSize() const { return m_WriteOffset; }
+ inline uint64_t GetSize() const { return m_WriteOffset + m_CacheBufferOffset; }
void ResetWritePos(uint64_t WriteOffset);
private:
@@ -143,6 +147,118 @@ namespace detail {
uint64_t m_BytesLeftInSegment;
};
+ class IncrementalStringMatcher
+ {
+ public:
+ enum class EMatchState
+ {
+ None,
+ Partial,
+ Complete
+ };
+
+ EMatchState MatchState = EMatchState::None;
+
+ IncrementalStringMatcher() {}
+
+ IncrementalStringMatcher(std::string&& InMatchString) : MatchString(std::move(InMatchString))
+ {
+ RawMatchString = MatchString.data();
+ }
+
+ void Init(std::string&& InMatchString)
+ {
+ MatchString = std::move(InMatchString);
+ RawMatchString = MatchString.data();
+ }
+
+ inline void Reset()
+ {
+ MatchLength = 0;
+ MatchStartOffset = 0;
+ MatchState = EMatchState::None;
+ }
+
+ inline uint64_t GetMatchEndOffset() const
+ {
+ if (MatchState == EMatchState::Complete)
+ {
+ return MatchStartOffset + MatchString.length();
+ }
+ return 0;
+ }
+
+ inline uint64_t GetMatchStartOffset() const
+ {
+ ZEN_ASSERT(MatchState == EMatchState::Complete);
+ return MatchStartOffset;
+ }
+
+ void Match(uint64_t Offset, char C)
+ {
+ ZEN_ASSERT_SLOW(RawMatchString != nullptr);
+
+ if (MatchState == EMatchState::Complete)
+ {
+ Reset();
+ }
+ if (C == RawMatchString[MatchLength])
+ {
+ if (MatchLength == 0)
+ {
+ MatchStartOffset = Offset;
+ }
+ MatchLength++;
+ if (MatchLength == MatchString.length())
+ {
+ MatchState = EMatchState::Complete;
+ }
+ else
+ {
+ MatchState = EMatchState::Partial;
+ }
+ }
+ else if (MatchLength != 0)
+ {
+ Reset();
+ Match(Offset, C);
+ }
+ else
+ {
+ Reset();
+ }
+ }
+ inline const std::string& GetMatchString() const { return MatchString; }
+
+ private:
+ std::string MatchString;
+ const char* RawMatchString = nullptr;
+ uint64_t MatchLength = 0;
+
+ uint64_t MatchStartOffset = 0;
+ };
+
+ class MultipartBoundaryParser
+ {
+ public:
+ std::vector<HttpClient::Response::MultipartBoundary> Boundaries;
+
+ MultipartBoundaryParser();
+ bool Init(const std::string_view ContentTypeHeaderValue);
+ void ParseInput(std::string_view data);
+
+ private:
+ IncrementalStringMatcher BoundaryBeginMatcher;
+ IncrementalStringMatcher BoundaryEndMatcher;
+ IncrementalStringMatcher HeaderEndMatcher;
+
+ ExtendableStringBuilder<64> BoundaryHeader;
+ uint64_t PayloadOffset = 0;
+ };
+
+ std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString);
+ std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value);
+
} // namespace detail
} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp
index 5d92b3b6b..a52b8f74b 100644
--- a/src/zenhttp/clients/httpclientcpr.cpp
+++ b/src/zenhttp/clients/httpclientcpr.cpp
@@ -7,11 +7,18 @@
#include <zencore/compactbinarypackage.h>
#include <zencore/compactbinaryutil.h>
#include <zencore/compress.h>
+#include <zencore/filesystem.h>
#include <zencore/iobuffer.h>
#include <zencore/iohash.h>
#include <zencore/session.h>
#include <zencore/stream.h>
#include <zenhttp/packageformat.h>
+#include <algorithm>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/ssl_options.h>
+#include <cpr/unix_socket.h>
+ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
@@ -23,69 +30,42 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti
static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
-// If we want to support different HTTP client implementations then we'll need to make this more abstract
+//////////////////////////////////////////////////////////////////////////
-HttpClientError::ResponseClass
-HttpClientError::GetResponseClass() const
+static HttpClientErrorCode
+MapCprError(cpr::ErrorCode Code)
{
- if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK)
- {
- switch ((cpr::ErrorCode)m_Error)
- {
- case cpr::ErrorCode::CONNECTION_FAILURE:
- return ResponseClass::kHttpCantConnectError;
- case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
- case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
- return ResponseClass::kHttpNoHost;
- case cpr::ErrorCode::INTERNAL_ERROR:
- case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
- case cpr::ErrorCode::NETWORK_SEND_FAILURE:
- case cpr::ErrorCode::OPERATION_TIMEDOUT:
- return ResponseClass::kHttpTimeout;
- case cpr::ErrorCode::SSL_CONNECT_ERROR:
- case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR:
- case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR:
- case cpr::ErrorCode::SSL_CACERT_ERROR:
- case cpr::ErrorCode::GENERIC_SSL_ERROR:
- return ResponseClass::kHttpSLLError;
- default:
- return ResponseClass::kHttpOtherClientError;
- }
- }
- else if (IsHttpSuccessCode(m_ResponseCode))
+ switch (Code)
{
- return ResponseClass::kSuccess;
- }
- else
- {
- switch (m_ResponseCode)
- {
- case HttpResponseCode::Unauthorized:
- return ResponseClass::kHttpUnauthorized;
- case HttpResponseCode::NotFound:
- return ResponseClass::kHttpNotFound;
- case HttpResponseCode::Forbidden:
- return ResponseClass::kHttpForbidden;
- case HttpResponseCode::Conflict:
- return ResponseClass::kHttpConflict;
- case HttpResponseCode::InternalServerError:
- return ResponseClass::kHttpInternalServerError;
- case HttpResponseCode::ServiceUnavailable:
- return ResponseClass::kHttpServiceUnavailable;
- case HttpResponseCode::BadGateway:
- return ResponseClass::kHttpBadGateway;
- case HttpResponseCode::GatewayTimeout:
- return ResponseClass::kHttpGatewayTimeout;
- default:
- if (m_ResponseCode >= HttpResponseCode::InternalServerError)
- {
- return ResponseClass::kHttpOtherServerError;
- }
- else
- {
- return ResponseClass::kHttpOtherClientError;
- }
- }
+ case cpr::ErrorCode::OK:
+ return HttpClientErrorCode::kOK;
+ case cpr::ErrorCode::CONNECTION_FAILURE:
+ return HttpClientErrorCode::kConnectionFailure;
+ case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
+ return HttpClientErrorCode::kHostResolutionFailure;
+ case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
+ return HttpClientErrorCode::kProxyResolutionFailure;
+ case cpr::ErrorCode::INTERNAL_ERROR:
+ return HttpClientErrorCode::kInternalError;
+ case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
+ return HttpClientErrorCode::kNetworkReceiveError;
+ case cpr::ErrorCode::NETWORK_SEND_FAILURE:
+ return HttpClientErrorCode::kNetworkSendFailure;
+ case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kOperationTimedOut;
+ case cpr::ErrorCode::SSL_CONNECT_ERROR:
+ return HttpClientErrorCode::kSSLConnectError;
+ case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR:
+ case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR:
+ return HttpClientErrorCode::kSSLCertificateError;
+ case cpr::ErrorCode::SSL_CACERT_ERROR:
+ return HttpClientErrorCode::kSSLCACertError;
+ case cpr::ErrorCode::GENERIC_SSL_ERROR:
+ return HttpClientErrorCode::kGenericSSLError;
+ case cpr::ErrorCode::REQUEST_CANCELLED:
+ return HttpClientErrorCode::kRequestCancelled;
+ default:
+ return HttpClientErrorCode::kOtherError;
}
}
@@ -149,6 +129,18 @@ CprHttpClient::CprHttpClient(std::string_view BaseUri,
{
}
+bool
+CprHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const
+{
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ // Quiet
+ return false;
+ }
+ const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes;
+ return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end();
+}
+
CprHttpClient::~CprHttpClient()
{
ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient");
@@ -162,10 +154,11 @@ CprHttpClient::~CprHttpClient()
}
HttpClient::Response
-CprHttpClient::ResponseWithPayload(std::string_view SessionId,
- cpr::Response&& HttpResponse,
- const HttpResponseCode WorkResponseCode,
- IoBuffer&& Payload)
+CprHttpClient::ResponseWithPayload(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
{
// This ends up doing a memcpy, would be good to get rid of it by streaming results
// into buffer directly
@@ -174,30 +167,37 @@ CprHttpClient::ResponseWithPayload(std::string_view SessionId,
if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end())
{
const HttpContentType ContentType = ParseContentType(It->second);
-
ResponseBuffer.SetContentType(ContentType);
}
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
-
- if (!Quiet)
+ if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
{
- if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
+ if (ShouldLogErrorCode(WorkResponseCode))
{
ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse);
}
}
+ std::sort(BoundaryPositions.begin(),
+ BoundaryPositions.end(),
+ [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) {
+ return Lhs.RangeOffset < Rhs.RangeOffset;
+ });
+
return HttpClient::Response{.StatusCode = WorkResponseCode,
.ResponsePayload = std::move(ResponseBuffer),
.Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
.UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
.DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
- .ElapsedSeconds = HttpResponse.elapsed};
+ .ElapsedSeconds = HttpResponse.elapsed,
+ .Ranges = std::move(BoundaryPositions)};
}
HttpClient::Response
-CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload)
+CprHttpClient::CommonResponse(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
{
const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code);
if (HttpResponse.error)
@@ -221,8 +221,8 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe
.UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
.DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
.ElapsedSeconds = HttpResponse.elapsed,
- .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code),
- .ErrorMessage = HttpResponse.error.message}};
+ .Error =
+ HttpClient::ErrorContext{.ErrorCode = MapCprError(HttpResponse.error.code), .ErrorMessage = HttpResponse.error.message}};
}
if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload))
@@ -235,7 +235,7 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe
}
else
{
- return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload));
+ return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions));
}
}
@@ -346,8 +346,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId,
}
Sleep(100 * (Attempt + 1));
Attempt++;
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
- if (!Quiet)
+ if (ShouldLogErrorCode(HttpResponseCode(Result.status_code)))
{
ZEN_INFO("{} Attempt {}/{}",
CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
@@ -385,8 +384,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId,
}
Sleep(100 * (Attempt + 1));
Attempt++;
- const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
- if (!Quiet)
+ if (ShouldLogErrorCode(HttpResponseCode(Result.status_code)))
{
ZEN_INFO("{} Attempt {}/{}",
CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
@@ -492,6 +490,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl,
{
CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}});
}
+ if (ConnectionSettings.ForbidReuseConnection)
+ {
+ CprSession->UpdateHeader({{"Connection", "close"}});
+ }
if (AccessToken)
{
CprSession->UpdateHeader({{"Authorization", AccessToken->Value}});
@@ -510,6 +512,26 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl,
CprSession->SetParameters({});
}
+ if (!ConnectionSettings.UnixSocketPath.empty())
+ {
+ CprSession->SetUnixSocket(cpr::UnixSocket(PathToUtf8(ConnectionSettings.UnixSocketPath)));
+ }
+
+ if (ConnectionSettings.InsecureSsl || !ConnectionSettings.CaBundlePath.empty())
+ {
+ cpr::SslOptions SslOpts;
+ if (ConnectionSettings.InsecureSsl)
+ {
+ SslOpts.SetOption(cpr::ssl::VerifyHost{false});
+ SslOpts.SetOption(cpr::ssl::VerifyPeer{false});
+ }
+ if (!ConnectionSettings.CaBundlePath.empty())
+ {
+ SslOpts.SetOption(cpr::ssl::CaInfo{ConnectionSettings.CaBundlePath});
+ }
+ CprSession->SetSslOptions(SslOpts);
+ }
+
ExtendableStringBuilder<128> UrlBuffer;
UrlBuffer << BaseUrl << ResourcePath;
CprSession->SetUrl(UrlBuffer.c_str());
@@ -621,7 +643,7 @@ CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const Ke
ResponseBuffer.SetContentType(ContentType);
}
- return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer};
+ return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = std::move(ResponseBuffer)};
}
//////////////////////////////////////////////////////////////////////////
@@ -774,22 +796,97 @@ CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentTyp
}
CprHttpClient::Response
-CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader)
+CprHttpClient::Post(std::string_view Url,
+ CbObject Payload,
+ const KeyValueMap& AdditionalHeader,
+ const std::filesystem::path& TempFolderPath)
{
ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload");
- return CommonResponse(
+ std::string PayloadString;
+ std::unique_ptr<detail::TempPayloadFile> PayloadFile;
+
+ cpr::Response Response = DoWithRetry(
m_SessionId,
- DoWithRetry(m_SessionId,
- [&]() {
- Session Sess =
- AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ [&]() {
+ PayloadString.clear();
+ PayloadFile.reset();
- Sess->SetBody(AsCprBody(Payload));
- Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)});
- return Sess.Post();
- }),
- {});
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+
+ Sess->SetBody(AsCprBody(Payload));
+ Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)});
+
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header);
+ if (StrCaseCompare(std::string(Header.first).c_str(), "Content-Length") == 0)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
+ if (ContentLength.has_value())
+ {
+ if (!TempFolderPath.empty() && ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize)
+ {
+ PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ PayloadFile.reset();
+ }
+ }
+ else
+ {
+ PayloadString.reserve(ContentLength.value());
+ }
+ }
+ }
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+ return 1;
+ };
+
+ auto DownloadCallback = [&](std::string data, intptr_t) {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return false;
+ }
+
+ if (PayloadFile)
+ {
+ ZEN_ASSERT(PayloadString.empty());
+ std::error_code Ec = PayloadFile->Write(data);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ return false;
+ }
+ }
+ else
+ {
+ PayloadString.append(data);
+ }
+ return true;
+ };
+ cpr::Response Response = Sess.Post({}, cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ if (!PayloadString.empty())
+ {
+ Response.text = std::move(PayloadString);
+ }
+ return Response;
+ },
+ PayloadFile);
+ return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{});
}
CprHttpClient::Response
@@ -896,236 +993,292 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF
std::string PayloadString;
std::unique_ptr<detail::TempPayloadFile> PayloadFile;
- cpr::Response Response = DoWithRetry(
- m_SessionId,
- [&]() {
- auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> {
- size_t DelimiterPos = header.find(':');
- if (DelimiterPos != std::string::npos)
- {
- std::string Key = header.substr(0, DelimiterPos);
- constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n");
- Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters);
- Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters);
-
- std::string Value = header.substr(DelimiterPos + 1);
- Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters);
- Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters);
-
- return std::make_pair(Key, Value);
- }
- return std::make_pair(header, "");
- };
-
- auto DownloadCallback = [&](std::string data, intptr_t) {
- if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
- {
- return false;
- }
- if (PayloadFile)
- {
- ZEN_ASSERT(PayloadString.empty());
- std::error_code Ec = PayloadFile->Write(data);
- if (Ec)
- {
- ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
- TempFolderPath.string(),
- Ec.message());
- return false;
- }
- }
- else
- {
- PayloadString.append(data);
- }
- return true;
- };
-
- uint64_t RequestedContentLength = (uint64_t)-1;
- if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
- {
- if (RangeIt->second.starts_with("bytes"))
- {
- size_t RangeStartPos = RangeIt->second.find('=', 5);
- if (RangeStartPos != std::string::npos)
- {
- RangeStartPos++;
- size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos);
- if (RangeSplitPos != std::string::npos)
- {
- std::optional<size_t> RequestedRangeStart =
- ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos));
- std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1));
- if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
- {
- RequestedContentLength = RequestedRangeEnd.value() - 1;
- }
- }
- }
- }
- }
-
- cpr::Response Response;
- {
- std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
- auto HeaderCallback = [&](std::string header, intptr_t) {
- std::pair<std::string, std::string> Header = GetHeader(header);
- if (Header.first == "Content-Length"sv)
- {
- std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
- if (ContentLength.has_value())
- {
- if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize)
- {
- PayloadFile = std::make_unique<detail::TempPayloadFile>();
- std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
- if (Ec)
- {
- ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
- TempFolderPath.string(),
- Ec.message());
- PayloadFile.reset();
- }
- }
- else
- {
- PayloadString.reserve(ContentLength.value());
- }
- }
- }
- if (!Header.first.empty())
- {
- ReceivedHeaders.emplace_back(std::move(Header));
- }
- return 1;
- };
-
- Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
- Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
- for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
- {
- Response.header.insert_or_assign(H.first, H.second);
- }
- }
- if (m_ConnectionSettings.AllowResume)
- {
- auto SupportsRanges = [](const cpr::Response& Response) -> bool {
- if (Response.header.find("Content-Range") != Response.header.end())
- {
- return true;
- }
- if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end())
- {
- return It->second == "bytes"sv;
- }
- return false;
- };
-
- auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool {
- if (ShouldRetry(Response))
- {
- return SupportsRanges(Response);
- }
- return false;
- };
-
- if (ShouldResume(Response))
- {
- auto It = Response.header.find("Content-Length");
- if (It != Response.header.end())
- {
- uint64_t ContentLength = RequestedContentLength;
- if (ContentLength == uint64_t(-1))
- {
- if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value())
- {
- ContentLength = ParsedContentLength.value();
- }
- }
-
- std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
-
- auto HeaderCallback = [&](std::string header, intptr_t) {
- std::pair<std::string, std::string> Header = GetHeader(header);
- if (!Header.first.empty())
- {
- ReceivedHeaders.emplace_back(std::move(Header));
- }
-
- if (Header.first == "Content-Range"sv)
- {
- if (Header.second.starts_with("bytes "sv))
- {
- size_t RangeStartEnd = Header.second.find('-', 6);
- if (RangeStartEnd != std::string::npos)
- {
- const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6));
- if (Start)
- {
- uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
- if (Start.value() == DownloadedSize)
- {
- return 1;
- }
- else if (Start.value() > DownloadedSize)
- {
- return 0;
- }
- if (PayloadFile)
- {
- PayloadFile->ResetWritePos(Start.value());
- }
- else
- {
- PayloadString = PayloadString.substr(0, Start.value());
- }
- return 1;
- }
- }
- }
- return 0;
- }
- return 1;
- };
-
- KeyValueMap HeadersWithRange(AdditionalHeader);
- do
- {
- uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
-
- std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
- if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
- {
- if (RangeIt->second == Range)
- {
- // If we didn't make any progress, abort
- break;
- }
- }
- HeadersWithRange.Entries.insert_or_assign("Range", Range);
-
- Session Sess =
- AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
- Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
- for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
- {
- Response.header.insert_or_assign(H.first, H.second);
- }
- ReceivedHeaders.clear();
- } while (ShouldResume(Response));
- }
- }
- }
-
- if (!PayloadString.empty())
- {
- Response.text = std::move(PayloadString);
- }
- return Response;
- },
- PayloadFile);
- return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{});
+ HttpContentType ContentType = HttpContentType::kUnknownContentType;
+ detail::MultipartBoundaryParser BoundaryParser;
+ bool IsMultiRangeResponse = false;
+
+ cpr::Response Response = DoWithRetry(
+ m_SessionId,
+ [&]() {
+ // Reset state from any previous attempt
+ PayloadString.clear();
+ PayloadFile.reset();
+ BoundaryParser.Boundaries.clear();
+ ContentType = HttpContentType::kUnknownContentType;
+ IsMultiRangeResponse = false;
+
+ auto DownloadCallback = [&](std::string data, intptr_t) {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return false;
+ }
+
+ if (IsMultiRangeResponse)
+ {
+ BoundaryParser.ParseInput(data);
+ }
+
+ if (PayloadFile)
+ {
+ ZEN_ASSERT(PayloadString.empty());
+ std::error_code Ec = PayloadFile->Write(data);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ return false;
+ }
+ }
+ else
+ {
+ PayloadString.append(data);
+ }
+ return true;
+ };
+
+ uint64_t RequestedContentLength = (uint64_t)-1;
+ if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
+ {
+ if (RangeIt->second.starts_with("bytes"))
+ {
+ std::string_view RangeValue(RangeIt->second);
+ size_t RangeStartPos = RangeValue.find('=', 5);
+ if (RangeStartPos != std::string::npos)
+ {
+ RangeStartPos++;
+ while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ')
+ {
+ RangeStartPos++;
+ }
+ RequestedContentLength = 0;
+
+ while (RangeStartPos < RangeValue.length())
+ {
+ size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos);
+ if (RangeEnd == std::string::npos)
+ {
+ RangeEnd = RangeValue.length();
+ }
+
+ std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos);
+ size_t RangeSplitPos = RangeString.find('-');
+ if (RangeSplitPos != std::string::npos)
+ {
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1;
+ }
+ }
+ RangeStartPos = RangeEnd;
+ while (RangeStartPos != RangeValue.length() &&
+ (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' '))
+ {
+ RangeStartPos++;
+ }
+ }
+ }
+ }
+ }
+
+ cpr::Response Response;
+ {
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ if (RequestedContentLength != (uint64_t)-1 && RequestedContentLength > m_ConnectionSettings.MaximumInMemoryDownloadSize)
+ {
+ ZEN_DEBUG("Multirange request");
+ }
+ const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header);
+ const std::string Key(Header.first);
+ if (StrCaseCompare(Key.c_str(), "Content-Length") == 0)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
+ if (ContentLength.has_value())
+ {
+ if (!TempFolderPath.empty() && ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize)
+ {
+ PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ PayloadFile.reset();
+ }
+ }
+ else
+ {
+ PayloadString.reserve(ContentLength.value());
+ }
+ }
+ }
+ else if (StrCaseCompare(Key.c_str(), "Content-Type") == 0)
+ {
+ IsMultiRangeResponse = BoundaryParser.Init(Header.second);
+ if (!IsMultiRangeResponse)
+ {
+ ContentType = ParseContentType(Header.second);
+ }
+ }
+ else if (StrCaseCompare(Key.c_str(), "Content-Range") == 0)
+ {
+ if (!IsMultiRangeResponse)
+ {
+ std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Header.second);
+ if (Range.second != 0)
+ {
+ BoundaryParser.Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0,
+ .RangeOffset = Range.first,
+ .RangeLength = Range.second,
+ .ContentType = ContentType});
+ }
+ }
+ }
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+ return 1;
+ };
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ }
+ if (m_ConnectionSettings.AllowResume)
+ {
+ auto SupportsRanges = [](const cpr::Response& Response) -> bool {
+ if (Response.header.find("Content-Range") != Response.header.end())
+ {
+ return true;
+ }
+ if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end())
+ {
+ return It->second == "bytes"sv;
+ }
+ return false;
+ };
+
+ auto ShouldResume = [&SupportsRanges, &IsMultiRangeResponse](const cpr::Response& Response) -> bool {
+ if (IsMultiRangeResponse)
+ {
+ return false;
+ }
+ if (ShouldRetry(Response))
+ {
+ return SupportsRanges(Response);
+ }
+ return false;
+ };
+
+ if (ShouldResume(Response))
+ {
+ auto It = Response.header.find("Content-Length");
+ if (It != Response.header.end())
+ {
+ uint64_t ContentLength = RequestedContentLength;
+ if (ContentLength == uint64_t(-1))
+ {
+ if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value())
+ {
+ ContentLength = ParsedContentLength.value();
+ }
+ }
+
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header);
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+
+ if (StrCaseCompare(std::string(Header.first).c_str(), "Content-Range") == 0)
+ {
+ if (Header.second.starts_with("bytes "sv))
+ {
+ size_t RangeStartEnd = Header.second.find('-', 6);
+ if (RangeStartEnd != std::string::npos)
+ {
+ const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6));
+ if (Start)
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+ if (Start.value() == DownloadedSize)
+ {
+ return 1;
+ }
+ else if (Start.value() > DownloadedSize)
+ {
+ return 0;
+ }
+ if (PayloadFile)
+ {
+ PayloadFile->ResetWritePos(Start.value());
+ }
+ else
+ {
+ PayloadString = PayloadString.substr(0, Start.value());
+ }
+ return 1;
+ }
+ }
+ }
+ return 0;
+ }
+ return 1;
+ };
+
+ KeyValueMap HeadersWithRange(AdditionalHeader);
+ do
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+
+ std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
+ if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
+ {
+ if (RangeIt->second == Range)
+ {
+ // If we didn't make any progress, abort
+ break;
+ }
+ }
+ HeadersWithRange.Entries.insert_or_assign("Range", Range);
+
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ ReceivedHeaders.clear();
+ } while (ShouldResume(Response));
+ }
+ }
+ }
+
+ if (!PayloadString.empty())
+ {
+ Response.text = std::move(PayloadString);
+ }
+ return Response;
+ },
+ PayloadFile);
+
+ return CommonResponse(m_SessionId,
+ std::move(Response),
+ PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{},
+ std::move(BoundaryParser.Boundaries));
}
} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h
index 40af53b5d..009e6fb7a 100644
--- a/src/zenhttp/clients/httpclientcpr.h
+++ b/src/zenhttp/clients/httpclientcpr.h
@@ -38,7 +38,10 @@ public:
const IoBuffer& Payload,
ZenContentType ContentType,
const KeyValueMap& AdditionalHeader = {}) override;
- [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ CbObject Payload,
+ const KeyValueMap& AdditionalHeader = {},
+ const std::filesystem::path& TempFolderPath = {}) override;
[[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override;
[[nodiscard]] virtual Response Post(std::string_view Url,
const CompositeBuffer& Payload,
@@ -104,15 +107,27 @@ private:
CprSession->SetReadCallback({});
return Result;
}
- inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {})
+ inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {},
+ std::optional<cpr::WriteCallback>&& Write = {},
+ std::optional<cpr::HeaderCallback>&& Header = {})
{
ZEN_TRACE_CPU("HttpClient::Impl::Post");
if (Read)
{
CprSession->SetReadCallback(std::move(Read.value()));
}
+ if (Write)
+ {
+ CprSession->SetWriteCallback(std::move(Write.value()));
+ }
+ if (Header)
+ {
+ CprSession->SetHeaderCallback(std::move(Header.value()));
+ }
cpr::Response Result = CprSession->Post();
ZEN_TRACE("POST {}", Result);
+ CprSession->SetHeaderCallback({});
+ CprSession->SetWriteCallback({});
CprSession->SetReadCallback({});
return Result;
}
@@ -155,14 +170,19 @@ private:
std::function<cpr::Response()>&& Func,
std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; });
+ bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const;
bool ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
- HttpClient::Response CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload);
+ HttpClient::Response CommonResponse(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {});
- HttpClient::Response ResponseWithPayload(std::string_view SessionId,
- cpr::Response&& HttpResponse,
- const HttpResponseCode WorkResponseCode,
- IoBuffer&& Payload);
+ HttpClient::Response ResponseWithPayload(std::string_view SessionId,
+ cpr::Response&& HttpResponse,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions);
};
} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp
new file mode 100644
index 000000000..ec9b7bac6
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcurl.cpp
@@ -0,0 +1,1816 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpclientcurl.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryutil.h>
+#include <zencore/compress.h>
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/session.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zenhttp/packageformat.h>
+#include <algorithm>
+
+namespace zen {
+
+HttpClientBase*
+CreateCurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction)
+{
+ return new CurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+}
+
+static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0};
+
+//////////////////////////////////////////////////////////////////////////
+
+static HttpClientErrorCode
+MapCurlError(CURLcode Code)
+{
+ switch (Code)
+ {
+ case CURLE_OK:
+ return HttpClientErrorCode::kOK;
+ case CURLE_COULDNT_CONNECT:
+ return HttpClientErrorCode::kConnectionFailure;
+ case CURLE_COULDNT_RESOLVE_HOST:
+ return HttpClientErrorCode::kHostResolutionFailure;
+ case CURLE_COULDNT_RESOLVE_PROXY:
+ return HttpClientErrorCode::kProxyResolutionFailure;
+ case CURLE_RECV_ERROR:
+ return HttpClientErrorCode::kNetworkReceiveError;
+ case CURLE_SEND_ERROR:
+ return HttpClientErrorCode::kNetworkSendFailure;
+ case CURLE_OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kOperationTimedOut;
+ case CURLE_SSL_CONNECT_ERROR:
+ return HttpClientErrorCode::kSSLConnectError;
+ case CURLE_SSL_CERTPROBLEM:
+ return HttpClientErrorCode::kSSLCertificateError;
+ case CURLE_PEER_FAILED_VERIFICATION:
+ return HttpClientErrorCode::kSSLCACertError;
+ case CURLE_SSL_CIPHER:
+ case CURLE_SSL_ENGINE_NOTFOUND:
+ case CURLE_SSL_ENGINE_SETFAILED:
+ return HttpClientErrorCode::kGenericSSLError;
+ case CURLE_ABORTED_BY_CALLBACK:
+ return HttpClientErrorCode::kRequestCancelled;
+ default:
+ return HttpClientErrorCode::kOtherError;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Curl callback helpers
+
+struct WriteCallbackData
+{
+ std::string* Body = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<WriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0; // Signal abort to curl
+ }
+
+ Data->Body->append(Ptr, TotalBytes);
+ return TotalBytes;
+}
+
+struct HeaderCallbackData
+{
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+};
+
+// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value.
+// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines).
+static std::optional<std::pair<std::string_view, std::string_view>>
+ParseHeaderLine(std::string_view Line)
+{
+ while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
+ {
+ Line.remove_suffix(1);
+ }
+
+ if (Line.empty())
+ {
+ return std::nullopt;
+ }
+
+ size_t ColonPos = Line.find(':');
+ if (ColonPos == std::string_view::npos)
+ {
+ return std::nullopt;
+ }
+
+ std::string_view Key = Line.substr(0, ColonPos);
+ std::string_view Value = Line.substr(ColonPos + 1);
+
+ while (!Key.empty() && Key.back() == ' ')
+ {
+ Key.remove_suffix(1);
+ }
+ while (!Value.empty() && Value.front() == ' ')
+ {
+ Value.remove_prefix(1);
+ }
+
+ return std::pair{Key, Value};
+}
+
+static size_t
+CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<HeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)))
+ {
+ auto& [Key, Value] = *Header;
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+}
+
+struct ReadCallbackData
+{
+ const uint8_t* DataPtr = nullptr;
+ size_t DataSize = 0;
+ size_t Offset = 0;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<ReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ size_t Remaining = Data->DataSize - Data->Offset;
+ size_t ToRead = std::min(MaxRead, Remaining);
+
+ if (ToRead > 0)
+ {
+ memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead);
+ Data->Offset += ToRead;
+ }
+
+ return ToRead;
+}
+
+struct StreamReadCallbackData
+{
+ detail::CompositeBufferReadStream* Reader = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlStreamReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<StreamReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ return Data->Reader->Read(Buffer, MaxRead);
+}
+
+struct FileReadCallbackData
+{
+ detail::BufferedReadFileStream* Buffer = nullptr;
+ uint64_t TotalSize = 0;
+ uint64_t Offset = 0;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlFileReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<FileReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ size_t Remaining = Data->TotalSize - Data->Offset;
+ size_t ToRead = std::min(MaxRead, Remaining);
+
+ if (ToRead > 0)
+ {
+ Data->Buffer->Read(Buffer, ToRead);
+ Data->Offset += ToRead;
+ }
+
+ return ToRead;
+}
+
+static int
+CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, void* UserPtr)
+{
+ ZEN_UNUSED(Handle);
+ LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr);
+ auto Log = [&]() -> LoggerRef { return LogRef; };
+
+ std::string_view DataView(Data, Size);
+
+ // Trim trailing newlines
+ while (!DataView.empty() && (DataView.back() == '\r' || DataView.back() == '\n'))
+ {
+ DataView.remove_suffix(1);
+ }
+
+ switch (Type)
+ {
+ case CURLINFO_TEXT:
+ if (DataView.find("need more data"sv) == std::string_view::npos)
+ {
+ ZEN_INFO("TEXT: {}", DataView);
+ }
+ break;
+ case CURLINFO_HEADER_IN:
+ ZEN_INFO("HIN : {}", DataView);
+ break;
+ case CURLINFO_HEADER_OUT:
+ if (auto TokenPos = DataView.find("Authorization: Bearer "sv); TokenPos != std::string_view::npos)
+ {
+ std::string Copy(DataView);
+ auto BearerStart = TokenPos + 22;
+ auto BearerEnd = Copy.find_first_of("\r\n", BearerStart);
+ if (BearerEnd == std::string::npos)
+ {
+ BearerEnd = Copy.length();
+ }
+ Copy.replace(Copy.begin() + BearerStart, Copy.begin() + BearerEnd, fmt::format("[{} char token]", BearerEnd - BearerStart));
+ ZEN_INFO("HOUT: {}", Copy);
+ }
+ else
+ {
+ ZEN_INFO("HOUT: {}", DataView);
+ }
+ break;
+ default:
+ break;
+ }
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+static std::pair<std::string, std::string>
+HeaderContentType(ZenContentType ContentType)
+{
+ return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)));
+}
+
+static curl_slist*
+BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
+ std::string_view SessionId,
+ const std::optional<HttpClientAccessToken>& AccessToken,
+ const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {})
+{
+ curl_slist* Headers = nullptr;
+
+ for (const auto& [Key, Value] : *AdditionalHeader)
+ {
+ ExtendableStringBuilder<64> HeaderLine;
+ HeaderLine << Key << ": " << Value;
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ if (!SessionId.empty())
+ {
+ ExtendableStringBuilder<64> SessionHeader;
+ SessionHeader << "UE-Session: " << SessionId;
+ Headers = curl_slist_append(Headers, SessionHeader.c_str());
+ }
+
+ if (AccessToken)
+ {
+ ExtendableStringBuilder<128> AuthHeader;
+ AuthHeader << "Authorization: " << AccessToken->Value;
+ Headers = curl_slist_append(Headers, AuthHeader.c_str());
+ }
+
+ for (const auto& [Key, Value] : ExtraHeaders)
+ {
+ ExtendableStringBuilder<128> HeaderLine;
+ HeaderLine << Key << ": " << Value;
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ return Headers;
+}
+
+static HttpClient::KeyValueMap
+BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers)
+{
+ HttpClient::KeyValueMap HeaderMap;
+ for (const auto& [Key, Value] : Headers)
+ {
+ HeaderMap->insert_or_assign(Key, Value);
+ }
+ return HeaderMap;
+}
+
+// Scans response headers for Content-Type and applies it to the buffer.
+static void
+ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers)
+{
+ for (const auto& [Key, Value] : Headers)
+ {
+ if (StrCaseCompare(Key, "Content-Type") == 0)
+ {
+ Buffer.SetContentType(ParseContentType(Value));
+ break;
+ }
+ }
+}
+
+static void
+AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input)
+{
+ static constexpr char HexDigits[] = "0123456789ABCDEF";
+ static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~");
+
+ for (char C : Input)
+ {
+ if (Unreserved.Contains(C))
+ {
+ Out.Append(C);
+ }
+ else
+ {
+ uint8_t Byte = static_cast<uint8_t>(C);
+ char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]};
+ Out.Append(std::string_view(Encoded, 3));
+ }
+ }
+}
+
+static void
+BuildUrlWithParameters(StringBuilderBase& Url,
+ std::string_view BaseUrl,
+ std::string_view ResourcePath,
+ const HttpClient::KeyValueMap& Parameters)
+{
+ Url.Append(BaseUrl);
+ Url.Append(ResourcePath);
+
+ if (!Parameters->empty())
+ {
+ char Separator = '?';
+ for (const auto& [Key, Value] : *Parameters)
+ {
+ Url.Append(Separator);
+ AppendUrlEncoded(Url, Key);
+ Url.Append('=');
+ AppendUrlEncoded(Url, Value);
+ Separator = '&';
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CurlHttpClient::CurlHttpClient(std::string_view BaseUri,
+ const HttpClientSettings& ConnectionSettings,
+ std::function<bool()>&& CheckIfAbortFunction)
+: HttpClientBase(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction))
+{
+}
+
+CurlHttpClient::~CurlHttpClient()
+{
+ ZEN_TRACE_CPU("CurlHttpClient::~CurlHttpClient");
+ m_SessionLock.WithExclusiveLock([&] {
+ for (auto* Handle : m_Sessions)
+ {
+ curl_easy_cleanup(Handle);
+ }
+ m_Sessions.clear();
+ });
+}
+
+CurlHttpClient::Session::~Session()
+{
+ if (HeaderList)
+ {
+ curl_slist_free_all(HeaderList);
+ }
+ Outer->ReleaseSession(Handle);
+}
+
+void
+CurlHttpClient::Session::SetHeaders(curl_slist* Headers)
+{
+ if (HeaderList)
+ {
+ curl_slist_free_all(HeaderList);
+ }
+ HeaderList = Headers;
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, HeaderList);
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::Session::PerformWithResponseCallbacks()
+{
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body,
+ .CheckIfAbortFunction = Outer->m_CheckIfAbortFunction ? &Outer->m_CheckIfAbortFunction : nullptr};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ return Result;
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::Session::Perform()
+{
+ CurlResult Result;
+
+ char ErrorBuffer[CURL_ERROR_SIZE] = {};
+ curl_easy_setopt(Handle, CURLOPT_ERRORBUFFER, ErrorBuffer);
+
+ Result.ErrorCode = curl_easy_perform(Handle);
+
+ if (Result.ErrorCode != CURLE_OK)
+ {
+ Result.ErrorMessage = ErrorBuffer[0] ? std::string(ErrorBuffer) : curl_easy_strerror(Result.ErrorCode);
+ }
+
+ curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &Result.StatusCode);
+
+ double Elapsed = 0;
+ curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed);
+ Result.ElapsedSeconds = Elapsed;
+
+ curl_off_t UpBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes);
+ Result.UploadedBytes = static_cast<int64_t>(UpBytes);
+
+ curl_off_t DownBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes);
+ Result.DownloadedBytes = static_cast<int64_t>(DownBytes);
+
+ return Result;
+}
+
+bool
+CurlHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const
+{
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return false;
+ }
+ const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes;
+ return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end();
+}
+
+HttpClient::Response
+CurlHttpClient::ResponseWithPayload(std::string_view SessionId,
+ CurlResult&& Result,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
+{
+ IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, Result.Body.data(), Result.Body.size());
+
+ ApplyContentTypeFromHeaders(ResponseBuffer, Result.Headers);
+
+ if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
+ {
+ if (ShouldLogErrorCode(WorkResponseCode))
+ {
+ ZEN_WARN("HttpClient request failed (session: {}): status={}, url={}",
+ SessionId,
+ static_cast<int>(WorkResponseCode),
+ m_BaseUri);
+ }
+ }
+
+ std::sort(BoundaryPositions.begin(),
+ BoundaryPositions.end(),
+ [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) {
+ return Lhs.RangeOffset < Rhs.RangeOffset;
+ });
+
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .ResponsePayload = std::move(ResponseBuffer),
+ .Header = BuildHeaderMap(Result.Headers),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Ranges = std::move(BoundaryPositions)};
+}
+
+HttpClient::Response
+CurlHttpClient::CommonResponse(std::string_view SessionId,
+ CurlResult&& Result,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
+{
+ const HttpResponseCode WorkResponseCode = HttpResponseCode(Result.StatusCode);
+ if (Result.ErrorCode != CURLE_OK)
+ {
+ const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
+ if (!Quiet)
+ {
+ if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT &&
+ Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK)
+ {
+ ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'",
+ SessionId,
+ static_cast<int>(Result.ErrorCode),
+ Result.ErrorMessage);
+ }
+ }
+
+ return HttpClient::Response{
+ .StatusCode = WorkResponseCode,
+ .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Result.Body.data(), Result.Body.size()),
+ .Header = BuildHeaderMap(Result.Headers),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(Result.ErrorCode), .ErrorMessage = Result.ErrorMessage}};
+ }
+
+ if (WorkResponseCode == HttpResponseCode::NoContent || (Result.Body.empty() && !Payload))
+ {
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .Header = BuildHeaderMap(Result.Headers),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds};
+ }
+ else
+ {
+ return ResponseWithPayload(SessionId, std::move(Result), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions));
+ }
+}
+
+bool
+CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile)
+{
+ ZEN_TRACE_CPU("ValidatePayload");
+
+ IoBuffer ResponseBuffer = (Result.Body.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer()
+ : IoBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size());
+
+ // Collect relevant headers in a single pass
+ std::string_view ContentLengthValue;
+ std::string_view IoHashValue;
+ std::string_view ContentTypeValue;
+
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ if (ContentLengthValue.empty() && StrCaseCompare(Key, "Content-Length") == 0)
+ {
+ ContentLengthValue = Value;
+ }
+ else if (IoHashValue.empty() && StrCaseCompare(Key, "X-Jupiter-IoHash") == 0)
+ {
+ IoHashValue = Value;
+ }
+ else if (ContentTypeValue.empty() && StrCaseCompare(Key, "Content-Type") == 0)
+ {
+ ContentTypeValue = Value;
+ }
+ }
+
+ // Validate Content-Length
+ if (!ContentLengthValue.empty())
+ {
+ std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLengthValue);
+ if (!ExpectedContentSize.has_value())
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLengthValue);
+ return false;
+ }
+ if (ExpectedContentSize.value() != ResponseBuffer.GetSize())
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage =
+ fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLengthValue);
+ return false;
+ }
+ }
+
+ if (Result.StatusCode == static_cast<long>(HttpResponseCode::PartialContent))
+ {
+ return true;
+ }
+
+ // Validate X-Jupiter-IoHash
+ if (!IoHashValue.empty())
+ {
+ IoHash ExpectedPayloadHash;
+ if (IoHash::TryParse(IoHashValue, ExpectedPayloadHash))
+ {
+ IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer);
+ if (PayloadHash != ExpectedPayloadHash)
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}",
+ PayloadHash.ToHexString(),
+ ExpectedPayloadHash.ToHexString());
+ return false;
+ }
+ }
+ }
+
+ // Validate content-type specific payload
+ if (ContentTypeValue == "application/x-ue-comp")
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer,
+ RawHash,
+ RawSize,
+ /*OutOptionalTotalCompressedSize*/ nullptr))
+ {
+ return true;
+ }
+ else
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = "Compressed binary failed validation";
+ return false;
+ }
+ }
+ if (ContentTypeValue == "application/x-ue-cb")
+ {
+ if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default);
+ Error == CbValidateError::None)
+ {
+ return true;
+ }
+ else
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error));
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool
+CurlHttpClient::ShouldRetry(const CurlResult& Result)
+{
+ switch (Result.ErrorCode)
+ {
+ case CURLE_OK:
+ break;
+ case CURLE_RECV_ERROR:
+ case CURLE_SEND_ERROR:
+ case CURLE_OPERATION_TIMEDOUT:
+ return true;
+ default:
+ return false;
+ }
+ switch (static_cast<HttpResponseCode>(Result.StatusCode))
+ {
+ case HttpResponseCode::RequestTimeout:
+ case HttpResponseCode::TooManyRequests:
+ case HttpResponseCode::InternalServerError:
+ case HttpResponseCode::BadGateway:
+ case HttpResponseCode::ServiceUnavailable:
+ case HttpResponseCode::GatewayTimeout:
+ return true;
+ default:
+ return false;
+ }
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::function<bool(CurlResult&)>&& Validate)
+{
+ uint8_t Attempt = 0;
+ CurlResult Result = Func();
+ while (Attempt < m_ConnectionSettings.RetryCount)
+ {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return Result;
+ }
+ if (!ShouldRetry(Result))
+ {
+ if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode))
+ {
+ break;
+ }
+ if (Validate(Result))
+ {
+ break;
+ }
+ }
+ Sleep(100 * (Attempt + 1));
+ Attempt++;
+ if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode)))
+ {
+ if (Result.ErrorCode != CURLE_OK)
+ {
+ ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}",
+ SessionId,
+ static_cast<int>(MapCurlError(Result.ErrorCode)),
+ Result.ErrorMessage,
+ Attempt,
+ m_ConnectionSettings.RetryCount + 1);
+ }
+ else
+ {
+ ZEN_INFO("Retry (session: {}): HTTP status ({}) '{}' Attempt {}/{}",
+ SessionId,
+ Result.StatusCode,
+ zen::ToString(HttpResponseCode(Result.StatusCode)),
+ Attempt,
+ m_ConnectionSettings.RetryCount + 1);
+ }
+ }
+ Result = Func();
+ }
+ return Result;
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::DoWithRetry(std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::unique_ptr<detail::TempPayloadFile>& PayloadFile)
+{
+ return DoWithRetry(SessionId, std::move(Func), [&](CurlResult& Result) { return ValidatePayload(Result, PayloadFile); });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CurlHttpClient::Session
+CurlHttpClient::AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::AllocSession");
+ CURL* Handle = nullptr;
+ m_SessionLock.WithExclusiveLock([&] {
+ if (!m_Sessions.empty())
+ {
+ Handle = m_Sessions.back();
+ m_Sessions.pop_back();
+ }
+ });
+
+ if (Handle == nullptr)
+ {
+ Handle = curl_easy_init();
+ if (Handle == nullptr)
+ {
+ ThrowOutOfMemory("curl_easy_init");
+ }
+ }
+ else
+ {
+ curl_easy_reset(Handle);
+ }
+
+ // Unix domain socket
+ if (!m_ConnectionSettings.UnixSocketPath.empty())
+ {
+ std::string SocketPathUtf8 = PathToUtf8(m_ConnectionSettings.UnixSocketPath);
+ curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, SocketPathUtf8.c_str());
+ }
+
+ // Build URL with parameters
+ ExtendableStringBuilder<256> Url;
+ BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str());
+
+ // Timeouts
+ if (m_ConnectionSettings.ConnectTimeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(m_ConnectionSettings.ConnectTimeout.count()));
+ }
+ if (m_ConnectionSettings.Timeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_ConnectionSettings.Timeout.count()));
+ }
+
+ // HTTP/2
+ if (m_ConnectionSettings.AssumeHttp2)
+ {
+ curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE);
+ }
+
+ // Verbose/debug
+ if (m_ConnectionSettings.Verbose)
+ {
+ curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L);
+ curl_easy_setopt(Handle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback);
+ curl_easy_setopt(Handle, CURLOPT_DEBUGDATA, &m_Log);
+ }
+
+ // SSL options
+ if (m_ConnectionSettings.InsecureSsl)
+ {
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L);
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L);
+ }
+ if (!m_ConnectionSettings.CaBundlePath.empty())
+ {
+ curl_easy_setopt(Handle, CURLOPT_CAINFO, m_ConnectionSettings.CaBundlePath.c_str());
+ }
+
+ // Disable signal handling for thread safety
+ curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L);
+
+ if (m_ConnectionSettings.ForbidReuseConnection)
+ {
+ curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L);
+ }
+
+ // Note: Headers are NOT set here. Each method builds its own header list
+ // (potentially adding method-specific headers like Content-Type) and passes
+ // ownership to the Session via SetHeaders().
+
+ return Session(this, Handle);
+}
+
+void
+CurlHttpClient::ReleaseSession(CURL* Handle)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::ReleaseSession");
+ m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(Handle); });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// TransactPackage is a two-phase protocol (offer + send) with server-side state
+// between phases, so retrying individual phases would be incorrect.
+CurlHttpClient::Response
+CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::TransactPackage");
+
+ // First, list of offered chunks for filtering on the server end
+
+ std::vector<IoHash> AttachmentsToSend;
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+
+ const uint32_t RequestId = ++CurlHttpClientRequestIdCounter;
+ auto RequestIdString = fmt::to_string(RequestId);
+
+ if (!Attachments.empty())
+ {
+ CbObjectWriter Writer;
+ Writer.BeginArray("offer");
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ Writer.AddHash(Attachment.GetHash());
+ }
+
+ Writer.EndArray();
+
+ BinaryWriter MemWriter;
+ Writer.Save(MemWriter);
+
+ std::vector<std::pair<std::string, std::string>> OfferExtraHeaders;
+ OfferExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackageOffer));
+ OfferExtraHeaders.emplace_back("UE-Request", RequestIdString);
+
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders));
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(MemWriter.Data()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(MemWriter.Size()));
+
+ CurlResult Result = Sess.PerformWithResponseCallbacks();
+
+ if (Result.ErrorCode == CURLE_OK && IsHttpSuccessCode(Result.StatusCode))
+ {
+ IoBuffer ResponseBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size());
+ CbValidateError ValidationError = CbValidateError::None;
+ if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError);
+ ValidationError == CbValidateError::None)
+ {
+ for (CbFieldView& Entry : ResponseObject["need"])
+ {
+ ZEN_ASSERT(Entry.IsHash());
+ AttachmentsToSend.push_back(Entry.AsHash());
+ }
+ }
+ }
+ }
+
+ // Prepare package for send
+
+ CbPackage SendPackage;
+ SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash());
+
+ for (const IoHash& AttachmentCid : AttachmentsToSend)
+ {
+ const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid);
+
+ if (Attachment)
+ {
+ SendPackage.AddAttachment(*Attachment);
+ }
+ }
+
+ // Transmit package payload
+
+ CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage);
+ SharedBuffer FlatMessage = Message.Flatten();
+
+ std::vector<std::pair<std::string, std::string>> PkgExtraHeaders;
+ PkgExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackage));
+ PkgExtraHeaders.emplace_back("UE-Request", RequestIdString);
+
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders));
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(FlatMessage.GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(FlatMessage.GetSize()));
+
+ CurlResult Result = Sess.PerformWithResponseCallbacks();
+
+ return CommonResponse(m_SessionId, std::move(Result), {}, {});
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Standard HTTP verbs
+//
+
+CurlHttpClient::Response
+CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Put");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}));
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Put");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}};
+ Session Sess = AllocSession(Url, Parameters);
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken()));
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL);
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Get");
+ return CommonResponse(m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, Parameters);
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(Sess.Get(), CURLOPT_HTTPGET, 1L);
+ return Sess.PerformWithResponseCallbacks();
+ },
+ [this](CurlResult& Result) {
+ std::unique_ptr<detail::TempPayloadFile> NoTempFile;
+ return ValidatePayload(Result, NoTempFile);
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Head");
+
+ return CommonResponse(m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(Sess.Get(), CURLOPT_NOBODY, 1L);
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Delete");
+
+ return CommonResponse(m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(Sess.Get(), CURLOPT_CUSTOMREQUEST, "DELETE");
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostNoPayload");
+
+ return CommonResponse(m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, Parameters);
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(Sess.Get(), CURLOPT_POST, 1L);
+ curl_easy_setopt(Sess.Get(), CURLOPT_POSTFIELDSIZE, 0L);
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader);
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostWithPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}));
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+
+ FileReadCallbackData ReadData{.Buffer = &Buffer,
+ .TotalSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url,
+ CbObject Payload,
+ const KeyValueMap& AdditionalHeader,
+ const std::filesystem::path& TempFolderPath)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostObjectPayload");
+
+ std::string PayloadString;
+ std::unique_ptr<detail::TempPayloadFile> PayloadFile;
+
+ CurlResult Result = DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ PayloadString.clear();
+ PayloadFile.reset();
+
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)}));
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetBuffer().GetSize()));
+
+ struct PostHeaderCallbackData
+ {
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::string* PayloadString = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ uint64_t MaxInMemorySize = 0;
+ LoggerRef Log;
+ };
+
+ PostHeaderCallbackData PostHdrData;
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ PostHdrData.Headers = &ResponseHeaders;
+ PostHdrData.PayloadFile = &PayloadFile;
+ PostHdrData.PayloadString = &PayloadString;
+ PostHdrData.TempFolderPath = &TempFolderPath;
+ PostHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize;
+ PostHdrData.Log = m_Log;
+
+ auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<PostHeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)))
+ {
+ auto& [Key, Value] = *Header;
+
+ if (StrCaseCompare(Key, "Content-Length") == 0)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Value);
+ if (ContentLength.has_value())
+ {
+ if (!Data->TempFolderPath->empty() && ContentLength.value() > Data->MaxInMemorySize)
+ {
+ *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ Data->PayloadFile->reset();
+ }
+ }
+ else
+ {
+ Data->PayloadString->reserve(ContentLength.value());
+ }
+ }
+ }
+
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb));
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &PostHdrData);
+
+ struct PostWriteCallbackData
+ {
+ std::string* PayloadString = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ LoggerRef Log;
+ };
+
+ PostWriteCallbackData PostWriteData;
+ PostWriteData.PayloadString = &PayloadString;
+ PostWriteData.PayloadFile = &PayloadFile;
+ PostWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr;
+ PostWriteData.TempFolderPath = &TempFolderPath;
+ PostWriteData.Log = m_Log;
+
+ auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<PostWriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0;
+ }
+
+ if (*Data->PayloadFile)
+ {
+ std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes));
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ return 0;
+ }
+ }
+ else
+ {
+ Data->PayloadString->append(Ptr, TotalBytes);
+ }
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb));
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &PostWriteData);
+
+ CurlResult Res = Sess.Perform();
+ Res.Headers = std::move(ResponseHeaders);
+
+ if (!PayloadString.empty())
+ {
+ Res.Body = std::move(PayloadString);
+ }
+
+ return Res;
+ },
+ PayloadFile);
+
+ return CommonResponse(m_SessionId, std::move(Result), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader);
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Post");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}));
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+
+ StreamReadCallbackData ReadData{.Reader = &Reader,
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}));
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+
+ FileReadCallbackData ReadData{.Buffer = &Buffer,
+ .TotalSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }
+
+ ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Upload(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}));
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+
+ StreamReadCallbackData ReadData{.Reader = &Reader,
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ return Sess.PerformWithResponseCallbacks();
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Download");
+
+ std::string PayloadString;
+ std::unique_ptr<detail::TempPayloadFile> PayloadFile;
+
+ HttpContentType ContentType = HttpContentType::kUnknownContentType;
+ detail::MultipartBoundaryParser BoundaryParser;
+ bool IsMultiRangeResponse = false;
+
+ CurlResult Result = DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(Url, {});
+ CURL* H = Sess.Get();
+
+ Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(H, CURLOPT_HTTPGET, 1L);
+
+ // Reset state from any previous attempt
+ PayloadString.clear();
+ PayloadFile.reset();
+ BoundaryParser.Boundaries.clear();
+ ContentType = HttpContentType::kUnknownContentType;
+ IsMultiRangeResponse = false;
+
+ // Track requested content length from Range header (sum all ranges)
+ uint64_t RequestedContentLength = (uint64_t)-1;
+ if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
+ {
+ if (RangeIt->second.starts_with("bytes"))
+ {
+ std::string_view RangeValue(RangeIt->second);
+ size_t RangeStartPos = RangeValue.find('=', 5);
+ if (RangeStartPos != std::string_view::npos)
+ {
+ RangeStartPos++;
+ while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ')
+ {
+ RangeStartPos++;
+ }
+ RequestedContentLength = 0;
+
+ while (RangeStartPos < RangeValue.length())
+ {
+ size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos);
+ if (RangeEnd == std::string_view::npos)
+ {
+ RangeEnd = RangeValue.length();
+ }
+
+ std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos);
+ size_t RangeSplitPos = RangeString.find('-');
+ if (RangeSplitPos != std::string_view::npos)
+ {
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1;
+ }
+ }
+ RangeStartPos = RangeEnd;
+ while (RangeStartPos != RangeValue.length() &&
+ (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' '))
+ {
+ RangeStartPos++;
+ }
+ }
+ }
+ }
+ }
+
+ // Header callback that detects Content-Length and switches to file-backed storage when needed
+ struct DownloadHeaderCallbackData
+ {
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::string* PayloadString = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ uint64_t MaxInMemorySize = 0;
+ LoggerRef Log;
+ detail::MultipartBoundaryParser* BoundaryParser = nullptr;
+ bool* IsMultiRange = nullptr;
+ HttpContentType* ContentTypeOut = nullptr;
+ };
+
+ DownloadHeaderCallbackData DlHdrData;
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ DlHdrData.Headers = &ResponseHeaders;
+ DlHdrData.PayloadFile = &PayloadFile;
+ DlHdrData.PayloadString = &PayloadString;
+ DlHdrData.TempFolderPath = &TempFolderPath;
+ DlHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize;
+ DlHdrData.Log = m_Log;
+ DlHdrData.BoundaryParser = &BoundaryParser;
+ DlHdrData.IsMultiRange = &IsMultiRangeResponse;
+ DlHdrData.ContentTypeOut = &ContentType;
+
+ auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<DownloadHeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)))
+ {
+ auto& [KeyView, Value] = *Header;
+ const std::string Key(KeyView);
+
+ if (StrCaseCompare(Key, "Content-Length") == 0)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Value);
+ if (ContentLength.has_value())
+ {
+ if (ContentLength.value() > Data->MaxInMemorySize)
+ {
+ *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ Data->PayloadFile->reset();
+ }
+ }
+ else
+ {
+ Data->PayloadString->reserve(ContentLength.value());
+ }
+ }
+ }
+ else if (StrCaseCompare(Key, "Content-Type") == 0)
+ {
+ *Data->IsMultiRange = Data->BoundaryParser->Init(Value);
+ if (!*Data->IsMultiRange)
+ {
+ *Data->ContentTypeOut = ParseContentType(Value);
+ }
+ }
+ else if (StrCaseCompare(Key, "Content-Range") == 0)
+ {
+ if (!*Data->IsMultiRange)
+ {
+ std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Value);
+ if (Range.second != 0)
+ {
+ Data->BoundaryParser->Boundaries.push_back(
+ HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0,
+ .RangeOffset = Range.first,
+ .RangeLength = Range.second,
+ .ContentType = *Data->ContentTypeOut});
+ }
+ }
+ }
+
+ Data->Headers->emplace_back(Key, std::string(Value));
+ }
+
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb));
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &DlHdrData);
+
+ // Write callback that directs data to file or string
+ struct DownloadWriteCallbackData
+ {
+ std::string* PayloadString = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ LoggerRef Log;
+ detail::MultipartBoundaryParser* BoundaryParser = nullptr;
+ bool* IsMultiRange = nullptr;
+ };
+
+ DownloadWriteCallbackData DlWriteData;
+ DlWriteData.PayloadString = &PayloadString;
+ DlWriteData.PayloadFile = &PayloadFile;
+ DlWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr;
+ DlWriteData.TempFolderPath = &TempFolderPath;
+ DlWriteData.Log = m_Log;
+ DlWriteData.BoundaryParser = &BoundaryParser;
+ DlWriteData.IsMultiRange = &IsMultiRangeResponse;
+
+ auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<DownloadWriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0;
+ }
+
+ if (*Data->IsMultiRange)
+ {
+ Data->BoundaryParser->ParseInput(std::string_view(Ptr, TotalBytes));
+ }
+
+ if (*Data->PayloadFile)
+ {
+ std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes));
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ return 0;
+ }
+ }
+ else
+ {
+ Data->PayloadString->append(Ptr, TotalBytes);
+ }
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb));
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &DlWriteData);
+
+ CurlResult Res = Sess.Perform();
+ Res.Headers = std::move(ResponseHeaders);
+
+ // Handle resume logic
+ if (m_ConnectionSettings.AllowResume)
+ {
+ auto SupportsRanges = [](const CurlResult& R) -> bool {
+ for (const auto& [K, V] : R.Headers)
+ {
+ if (StrCaseCompare(K, "Content-Range") == 0)
+ {
+ return true;
+ }
+ if (StrCaseCompare(K, "Accept-Ranges") == 0)
+ {
+ return V == "bytes"sv;
+ }
+ }
+ return false;
+ };
+
+ auto ShouldResumeCheck = [&SupportsRanges, &IsMultiRangeResponse](const CurlResult& R) -> bool {
+ if (IsMultiRangeResponse)
+ {
+ return false;
+ }
+ if (ShouldRetry(R))
+ {
+ return SupportsRanges(R);
+ }
+ return false;
+ };
+
+ if (ShouldResumeCheck(Res))
+ {
+ // Find Content-Length
+ std::string ContentLengthValue;
+ for (const auto& [K, V] : Res.Headers)
+ {
+ if (StrCaseCompare(K, "Content-Length") == 0)
+ {
+ ContentLengthValue = V;
+ break;
+ }
+ }
+
+ if (!ContentLengthValue.empty())
+ {
+ uint64_t ContentLength = RequestedContentLength;
+ if (ContentLength == uint64_t(-1))
+ {
+ if (auto ParsedContentLength = ParseInt<int64_t>(ContentLengthValue); ParsedContentLength.has_value())
+ {
+ ContentLength = ParsedContentLength.value();
+ }
+ }
+
+ KeyValueMap HeadersWithRange(AdditionalHeader);
+ uint8_t ResumeAttempt = 0;
+ do
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+
+ std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
+ if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
+ {
+ if (RangeIt->second == Range)
+ {
+ break; // No progress, abort
+ }
+ }
+ HeadersWithRange.Entries.insert_or_assign("Range", Range);
+
+ Session ResumeSess = AllocSession(Url, {});
+ CURL* ResumeH = ResumeSess.Get();
+
+ ResumeSess.SetHeaders(BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken()));
+ curl_easy_setopt(ResumeH, CURLOPT_HTTPGET, 1L);
+
+ std::vector<std::pair<std::string, std::string>> ResumeHeaders;
+
+ struct ResumeHeaderCbData
+ {
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::string* PayloadString = nullptr;
+ };
+
+ ResumeHeaderCbData ResumeHdrData;
+ ResumeHdrData.Headers = &ResumeHeaders;
+ ResumeHdrData.PayloadFile = &PayloadFile;
+ ResumeHdrData.PayloadString = &PayloadString;
+
+ auto ResumeHeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<ResumeHeaderCbData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes));
+ if (!Header)
+ {
+ return TotalBytes;
+ }
+ auto& [Key, Value] = *Header;
+
+ if (StrCaseCompare(Key, "Content-Range") == 0)
+ {
+ if (Value.starts_with("bytes "sv))
+ {
+ size_t RangeStartEnd = Value.find('-', 6);
+ if (RangeStartEnd != std::string_view::npos)
+ {
+ const std::optional<uint64_t> Start = ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6));
+ if (Start)
+ {
+ uint64_t DownloadedSize =
+ *Data->PayloadFile ? (*Data->PayloadFile)->GetSize() : Data->PayloadString->length();
+ if (Start.value() == DownloadedSize)
+ {
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ return TotalBytes;
+ }
+ else if (Start.value() > DownloadedSize)
+ {
+ return 0;
+ }
+ if (*Data->PayloadFile)
+ {
+ (*Data->PayloadFile)->ResetWritePos(Start.value());
+ }
+ else
+ {
+ *Data->PayloadString = Data->PayloadString->substr(0, Start.value());
+ }
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ return TotalBytes;
+ }
+ }
+ }
+ return 0;
+ }
+
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(ResumeH,
+ CURLOPT_HEADERFUNCTION,
+ static_cast<size_t (*)(char*, size_t, size_t, void*)>(ResumeHeaderCb));
+ curl_easy_setopt(ResumeH, CURLOPT_HEADERDATA, &ResumeHdrData);
+ curl_easy_setopt(ResumeH,
+ CURLOPT_WRITEFUNCTION,
+ static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb));
+ curl_easy_setopt(ResumeH, CURLOPT_WRITEDATA, &DlWriteData);
+
+ Res = ResumeSess.Perform();
+ Res.Headers = std::move(ResumeHeaders);
+
+ ResumeAttempt++;
+ } while (ResumeAttempt < m_ConnectionSettings.RetryCount && ShouldResumeCheck(Res));
+ }
+ }
+ }
+
+ if (!PayloadString.empty())
+ {
+ Res.Body = std::move(PayloadString);
+ }
+
+ return Res;
+ },
+ PayloadFile);
+
+ return CommonResponse(m_SessionId,
+ std::move(Result),
+ PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{},
+ std::move(BoundaryParser.Boundaries));
+}
+
+} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h
new file mode 100644
index 000000000..b7fa52e6c
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcurl.h
@@ -0,0 +1,137 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "httpclientcommon.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/httpclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <curl/curl.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+class CurlHttpClient : public HttpClientBase
+{
+public:
+ CurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction);
+ ~CurlHttpClient();
+
+ // HttpClientBase
+
+ [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Get(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ CbObject Payload,
+ const KeyValueMap& AdditionalHeader = {},
+ const std::filesystem::path& TempFolderPath = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Upload(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+ [[nodiscard]] virtual Response Download(std::string_view Url,
+ const std::filesystem::path& TempFolderPath,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+ [[nodiscard]] virtual Response TransactPackage(std::string_view Url,
+ CbPackage Package,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+private:
+ struct CurlResult
+ {
+ long StatusCode = 0;
+ std::string Body;
+ std::vector<std::pair<std::string, std::string>> Headers;
+ double ElapsedSeconds = 0;
+ int64_t UploadedBytes = 0;
+ int64_t DownloadedBytes = 0;
+ CURLcode ErrorCode = CURLE_OK;
+ std::string ErrorMessage;
+ };
+
+ struct Session
+ {
+ Session(CurlHttpClient* InOuter, CURL* InHandle) : Outer(InOuter), Handle(InHandle) {}
+ ~Session();
+
+ CURL* Get() const { return Handle; }
+
+ // Takes ownership of the curl_slist and sets it on the handle.
+ // The list is freed automatically when the Session is destroyed.
+ void SetHeaders(curl_slist* Headers);
+
+ // Low-level perform: executes the request and collects status/timing.
+ CurlResult Perform();
+
+ // Sets up standard write+header callbacks, performs the request, and
+ // moves the collected body and headers into the returned CurlResult.
+ CurlResult PerformWithResponseCallbacks();
+
+ LoggerRef Log() { return Outer->Log(); }
+
+ private:
+ CurlHttpClient* Outer;
+ CURL* Handle;
+ curl_slist* HeaderList = nullptr;
+
+ Session(Session&&) = delete;
+ Session& operator=(Session&&) = delete;
+ };
+
+ Session AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters);
+
+ RwLock m_SessionLock;
+ std::vector<CURL*> m_Sessions;
+
+ void ReleaseSession(CURL* Handle);
+
+ CurlResult DoWithRetry(std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
+ CurlResult DoWithRetry(
+ std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::function<bool(CurlResult&)>&& Validate = [](CurlResult&) { return true; });
+
+ bool ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
+
+ static bool ShouldRetry(const CurlResult& Result);
+
+ bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const;
+
+ HttpClient::Response CommonResponse(std::string_view SessionId,
+ CurlResult&& Result,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {});
+
+ HttpClient::Response ResponseWithPayload(std::string_view SessionId,
+ CurlResult&& Result,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions);
+};
+
+} // namespace zen
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
new file mode 100644
index 000000000..fbae9f5fe
--- /dev/null
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -0,0 +1,641 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpwsclient.h>
+
+#include "../servers/wsframecodec.h"
+
+#include <zencore/base64.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <deque>
+#include <random>
+#include <thread>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpWsClient::Impl
+{
+ Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_OwnedIoContext(std::make_unique<asio::io_context>())
+ , m_IoContext(*m_OwnedIoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_IoContext(IoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ ~Impl()
+ {
+ // Release work guard so io_context::run() can return
+ m_WorkGuard.reset();
+
+ // Close the socket to cancel pending async ops
+ CloseSocket();
+
+ if (m_IoThread.joinable())
+ {
+ m_IoThread.join();
+ }
+ }
+
+ void CloseSocket()
+ {
+ asio::error_code Ec;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixSocket)
+ {
+ m_UnixSocket->close(Ec);
+ return;
+ }
+#endif
+ if (m_TcpSocket)
+ {
+ m_TcpSocket->close(Ec);
+ }
+ }
+
+ template<typename Fn>
+ void WithSocket(Fn&& Func)
+ {
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixSocket)
+ {
+ Func(*m_UnixSocket);
+ return;
+ }
+#endif
+ Func(*m_TcpSocket);
+ }
+
+ void ParseUrl(std::string_view Url)
+ {
+ // Expected format: ws://host:port/path
+ if (Url.substr(0, 5) == "ws://")
+ {
+ Url.remove_prefix(5);
+ }
+
+ auto SlashPos = Url.find('/');
+ std::string_view HostPort;
+ if (SlashPos != std::string_view::npos)
+ {
+ HostPort = Url.substr(0, SlashPos);
+ m_Path = std::string(Url.substr(SlashPos));
+ }
+ else
+ {
+ HostPort = Url;
+ m_Path = "/";
+ }
+
+ auto ColonPos = HostPort.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ m_Host = std::string(HostPort.substr(0, ColonPos));
+ m_Port = std::string(HostPort.substr(ColonPos + 1));
+ }
+ else
+ {
+ m_Host = std::string(HostPort);
+ m_Port = "80";
+ }
+ }
+
+ void Connect()
+ {
+ if (m_OwnedIoContext)
+ {
+ m_WorkGuard.emplace(m_IoContext.get_executor());
+ m_IoThread = std::thread([this] { m_IoContext.run(); });
+ }
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (!m_Settings.UnixSocketPath.empty())
+ {
+ asio::post(m_IoContext, [this] { DoConnectUnix(); });
+ return;
+ }
+#endif
+
+ asio::post(m_IoContext, [this] { DoResolve(); });
+ }
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ void DoConnectUnix()
+ {
+ m_UnixSocket = std::make_unique<asio::local::stream_protocol::socket>(m_IoContext);
+
+ // Start connect timeout timer
+ m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
+ m_Timer->async_wait([this](const asio::error_code& Ec) {
+ if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect timeout for {}", m_Settings.UnixSocketPath);
+ CloseSocket();
+ }
+ });
+
+ asio::local::stream_protocol::endpoint Endpoint(PathToUtf8(m_Settings.UnixSocketPath));
+ m_UnixSocket->async_connect(Endpoint, [this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect failed for {}: {}", m_Settings.UnixSocketPath, Ec.message());
+ m_Handler.OnWsClose(1006, "connect failed");
+ return;
+ }
+
+ DoHandshake();
+ });
+ }
+#endif
+
+ void DoResolve()
+ {
+ m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext);
+
+ m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) {
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "resolve failed");
+ return;
+ }
+
+ DoConnect(Results);
+ });
+ }
+
+ void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints)
+ {
+ m_TcpSocket = std::make_unique<asio::ip::tcp::socket>(m_IoContext);
+
+ // Start connect timeout timer
+ m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
+ m_Timer->async_wait([this](const asio::error_code& Ec) {
+ if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port);
+ CloseSocket();
+ }
+ });
+
+ asio::async_connect(*m_TcpSocket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "connect failed");
+ return;
+ }
+
+ DoHandshake();
+ });
+ }
+
+ void DoHandshake()
+ {
+ // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded)
+ uint8_t KeyBytes[16];
+ {
+ static thread_local std::mt19937 s_Rng(std::random_device{}());
+ for (int i = 0; i < 4; ++i)
+ {
+ uint32_t Val = s_Rng();
+ std::memcpy(KeyBytes + i * 4, &Val, 4);
+ }
+ }
+
+ char KeyBase64[Base64::GetEncodedDataSize(16) + 1];
+ uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64);
+ KeyBase64[KeyLen] = '\0';
+ m_WebSocketKey = std::string(KeyBase64, KeyLen);
+
+ // Build the HTTP upgrade request
+ ExtendableStringBuilder<512> Request;
+ Request << "GET " << m_Path << " HTTP/1.1\r\n"
+ << "Host: " << m_Host << ":" << m_Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n"
+ << "Sec-WebSocket-Version: 13\r\n";
+
+ // Add Authorization header if access token provider is set
+ if (m_Settings.AccessTokenProvider)
+ {
+ HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)();
+ if (Token.IsValid())
+ {
+ Request << "Authorization: Bearer " << Token.Value << "\r\n";
+ }
+ }
+
+ Request << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+
+ m_HandshakeBuffer = std::make_shared<std::string>(ReqStr);
+
+ WithSocket([this](auto& Socket) {
+ asio::async_write(Socket,
+ asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()),
+ [this](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake write failed");
+ return;
+ }
+
+ DoReadHandshakeResponse();
+ });
+ });
+ }
+
+ void DoReadHandshakeResponse()
+ {
+ WithSocket([this](auto& Socket) {
+ asio::async_read_until(Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) {
+ m_Timer->cancel();
+
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake read failed");
+ return;
+ }
+
+ // Parse the response
+ const auto& Data = m_ReadBuffer.data();
+ std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
+
+ // Consume the headers from the read buffer (any extra data stays for frame parsing)
+ auto HeaderEnd = Response.find("\r\n\r\n");
+ if (HeaderEnd != std::string::npos)
+ {
+ m_ReadBuffer.consume(HeaderEnd + 4);
+ }
+
+ // Validate 101 response
+ if (Response.find("101") == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
+ m_Handler.OnWsClose(1006, "handshake rejected");
+ return;
+ }
+
+ // Validate Sec-WebSocket-Accept
+ std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
+ if (Response.find(ExpectedAccept) == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
+ m_Handler.OnWsClose(1006, "invalid accept key");
+ return;
+ }
+
+ m_IsOpen.store(true);
+ m_Handler.OnWsOpen();
+ EnqueueRead();
+ });
+ });
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Read loop
+ //
+
+ void EnqueueRead()
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ WithSocket([this](auto& Socket) {
+ asio::async_read(Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) {
+ OnDataReceived(Ec);
+ });
+ });
+ }
+
+ void OnDataReceived(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message());
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "connection lost");
+ }
+ return;
+ }
+
+ ProcessReceivedData();
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ EnqueueRead();
+ }
+ }
+
+ void ProcessReceivedData()
+ {
+ while (m_ReadBuffer.size() > 0)
+ {
+ const auto& InputBuffer = m_ReadBuffer.data();
+ const auto* RawData = static_cast<const uint8_t*>(InputBuffer.data());
+ const auto Size = InputBuffer.size();
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size);
+ if (!Frame.IsValid)
+ {
+ break;
+ }
+
+ m_ReadBuffer.consume(Frame.BytesConsumed);
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWsMessage(Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with masked pong
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason =
+ std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo masked close frame if we haven't sent one yet
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWsClose(Code, Reason);
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Write queue
+ //
+
+ void EnqueueWrite(std::vector<uint8_t> Frame)
+ {
+ bool ShouldFlush = false;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_WriteQueue.push_back(std::move(Frame));
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ });
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+ }
+
+ void FlushWriteQueue()
+ {
+ std::vector<uint8_t> Frame;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+ Frame = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ });
+
+ if (Frame.empty())
+ {
+ return;
+ }
+
+ auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
+
+ WithSocket([this, OwnedFrame](auto& Socket) {
+ asio::async_write(Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); });
+ });
+ }
+
+ void OnWriteComplete(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message());
+ }
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_IsWriting = false;
+ m_WriteQueue.clear();
+ });
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "write error");
+ }
+ return;
+ }
+
+ FlushWriteQueue();
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Public operations
+ //
+
+ void SendText(std::string_view Text)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void SendBinary(std::span<const uint8_t> Data)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void DoClose(uint16_t Code, std::string_view Reason)
+ {
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ IWsClientHandler& m_Handler;
+ HttpWsClientSettings m_Settings;
+ LoggerRef m_Log;
+
+ std::string m_Host;
+ std::string m_Port;
+ std::string m_Path;
+
+ // io_context: owned (standalone) or external (shared)
+ std::unique_ptr<asio::io_context> m_OwnedIoContext;
+ asio::io_context& m_IoContext;
+ std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard;
+ std::thread m_IoThread;
+
+ // Connection state
+ std::unique_ptr<asio::ip::tcp::resolver> m_Resolver;
+ std::unique_ptr<asio::ip::tcp::socket> m_TcpSocket;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ std::unique_ptr<asio::local::stream_protocol::socket> m_UnixSocket;
+#endif
+ std::unique_ptr<asio::steady_timer> m_Timer;
+ asio::streambuf m_ReadBuffer;
+ std::string m_WebSocketKey;
+ std::shared_ptr<std::string> m_HandshakeBuffer;
+
+ // Write queue
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ bool m_IsWriting = false;
+
+ std::atomic<bool> m_IsOpen{false};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, Settings))
+{
+}
+
+HttpWsClient::HttpWsClient(std::string_view Url,
+ IWsClientHandler& Handler,
+ asio::io_context& IoContext,
+ const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, IoContext, Settings))
+{
+}
+
+HttpWsClient::~HttpWsClient() = default;
+
+void
+HttpWsClient::Connect()
+{
+ m_Impl->Connect();
+}
+
+void
+HttpWsClient::SendText(std::string_view Text)
+{
+ m_Impl->SendText(Text);
+}
+
+void
+HttpWsClient::SendBinary(std::span<const uint8_t> Data)
+{
+ m_Impl->SendBinary(Data);
+}
+
+void
+HttpWsClient::Close(uint16_t Code, std::string_view Reason)
+{
+ m_Impl->DoClose(Code, Reason);
+}
+
+bool
+HttpWsClient::IsOpen() const
+{
+ return m_Impl->m_IsOpen.load(std::memory_order_relaxed);
+}
+
+} // namespace zen