aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/clients/httpclientcpr.cpp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2025-09-30 19:07:51 +0200
committerGitHub Enterprise <[email protected]>2025-09-30 19:07:51 +0200
commit634181a04efff90def7a98d98eac7078e1d4e62d (patch)
tree04678bba636a76d21f300ff6e73af4473274cf12 /src/zenhttp/clients/httpclientcpr.cpp
parentuse batching clang-format for quicker turnaround on validate actions (#529) (diff)
downloadzen-634181a04efff90def7a98d98eac7078e1d4e62d.tar.xz
zen-634181a04efff90def7a98d98eac7078e1d4e62d.zip
HttpClient support for pluggable back-ends (#532)
refactored HttpClient to separate out cpr implementation into separate classes, with an abstract base class to allow plugging in multiple implementations in the future
Diffstat (limited to 'src/zenhttp/clients/httpclientcpr.cpp')
-rw-r--r--src/zenhttp/clients/httpclientcpr.cpp1035
1 files changed, 1035 insertions, 0 deletions
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp
new file mode 100644
index 000000000..568106887
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcpr.cpp
@@ -0,0 +1,1035 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpclientcpr.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryutil.h>
+#include <zencore/compress.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/session.h>
+#include <zencore/stream.h>
+#include <zenhttp/packageformat.h>
+
+namespace zen {
+
+HttpClientBase*
+CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings)
+{
+ return new CprHttpClient(BaseUri, ConnectionSettings);
+}
+
+static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
+
+// If we want to support different HTTP client implementations then we'll need to make this more abstract
+
+HttpClientError::ResponseClass
+HttpClientError::GetResponseClass() const
+{
+ if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK)
+ {
+ switch ((cpr::ErrorCode)m_Error)
+ {
+ case cpr::ErrorCode::CONNECTION_FAILURE:
+ return ResponseClass::kHttpCantConnectError;
+ case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
+ case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
+ return ResponseClass::kHttpNoHost;
+ case cpr::ErrorCode::INTERNAL_ERROR:
+ case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
+ case cpr::ErrorCode::NETWORK_SEND_FAILURE:
+ case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ return ResponseClass::kHttpTimeout;
+ case cpr::ErrorCode::SSL_CONNECT_ERROR:
+ case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR:
+ case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR:
+ case cpr::ErrorCode::SSL_CACERT_ERROR:
+ case cpr::ErrorCode::GENERIC_SSL_ERROR:
+ return ResponseClass::kHttpSLLError;
+ default:
+ return ResponseClass::kHttpOtherClientError;
+ }
+ }
+ else if (IsHttpSuccessCode(m_ResponseCode))
+ {
+ return ResponseClass::kSuccess;
+ }
+ else
+ {
+ switch (m_ResponseCode)
+ {
+ case HttpResponseCode::Unauthorized:
+ return ResponseClass::kHttpUnauthorized;
+ case HttpResponseCode::NotFound:
+ return ResponseClass::kHttpNotFound;
+ case HttpResponseCode::Forbidden:
+ return ResponseClass::kHttpForbidden;
+ case HttpResponseCode::Conflict:
+ return ResponseClass::kHttpConflict;
+ case HttpResponseCode::InternalServerError:
+ return ResponseClass::kHttpInternalServerError;
+ case HttpResponseCode::ServiceUnavailable:
+ return ResponseClass::kHttpServiceUnavailable;
+ case HttpResponseCode::BadGateway:
+ return ResponseClass::kHttpBadGateway;
+ case HttpResponseCode::GatewayTimeout:
+ return ResponseClass::kHttpGatewayTimeout;
+ default:
+ if (m_ResponseCode >= HttpResponseCode::InternalServerError)
+ {
+ return ResponseClass::kHttpOtherServerError;
+ }
+ else
+ {
+ return ResponseClass::kHttpOtherClientError;
+ }
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// CPR helpers
+
+static cpr::Body
+AsCprBody(const CbObject& Obj)
+{
+ return cpr::Body((const char*)Obj.GetBuffer().GetData(), Obj.GetBuffer().GetSize());
+}
+
+static cpr::Body
+AsCprBody(const IoBuffer& Obj)
+{
+ return cpr::Body((const char*)Obj.GetData(), Obj.GetSize());
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+static HttpClient::Response
+ResponseWithPayload(std::string_view SessionId, cpr::Response& HttpResponse, const HttpResponseCode WorkResponseCode, IoBuffer&& Payload)
+{
+ // This ends up doing a memcpy, would be good to get rid of it by streaming results
+ // into buffer directly
+ IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, HttpResponse.text.data(), HttpResponse.text.size());
+
+ if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end())
+ {
+ const HttpContentType ContentType = ParseContentType(It->second);
+
+ ResponseBuffer.SetContentType(ContentType);
+ }
+
+ if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
+ {
+ ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse);
+ }
+
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .ResponsePayload = std::move(ResponseBuffer),
+ .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
+ .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
+ .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
+ .ElapsedSeconds = HttpResponse.elapsed};
+}
+
+static HttpClient::Response
+CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload = {})
+{
+ const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code);
+ if (HttpResponse.error)
+ {
+ if (HttpResponse.error.code != cpr::ErrorCode::OPERATION_TIMEDOUT &&
+ HttpResponse.error.code != cpr::ErrorCode::CONNECTION_FAILURE && HttpResponse.error.code != cpr::ErrorCode::REQUEST_CANCELLED)
+ {
+ ZEN_WARN("HttpClient client failure (session: {}): {}", SessionId, HttpResponse);
+ }
+
+ // Client side failure code
+ return HttpClient::Response{
+ .StatusCode = WorkResponseCode,
+ .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()),
+ .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
+ .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
+ .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
+ .ElapsedSeconds = HttpResponse.elapsed,
+ .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code),
+ .ErrorMessage = HttpResponse.error.message}};
+ }
+
+ if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload))
+ {
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
+ .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
+ .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
+ .ElapsedSeconds = HttpResponse.elapsed};
+ }
+ else
+ {
+ return ResponseWithPayload(
+ SessionId,
+ HttpResponse,
+ WorkResponseCode,
+ Payload ? std::move(Payload) : IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()));
+ }
+}
+
+static bool
+ShouldRetry(const cpr::Response& Response)
+{
+ switch (Response.error.code)
+ {
+ case cpr::ErrorCode::OK:
+ break;
+ case cpr::ErrorCode::INTERNAL_ERROR:
+ case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
+ case cpr::ErrorCode::NETWORK_SEND_FAILURE:
+ case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ return true;
+ default:
+ return false;
+ }
+ switch ((HttpResponseCode)Response.status_code)
+ {
+ case HttpResponseCode::RequestTimeout:
+ case HttpResponseCode::TooManyRequests:
+ case HttpResponseCode::InternalServerError:
+ case HttpResponseCode::BadGateway:
+ case HttpResponseCode::ServiceUnavailable:
+ case HttpResponseCode::GatewayTimeout:
+ return true;
+ default:
+ return false;
+ }
+};
+
+static bool
+ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile)
+{
+ ZEN_TRACE_CPU("ValidatePayload");
+ IoBuffer ResponseBuffer = (Response.text.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer()
+ : IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size());
+
+ if (auto ContentLength = Response.header.find("Content-Length"); ContentLength != Response.header.end())
+ {
+ std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLength->second);
+ if (!ExpectedContentSize.has_value())
+ {
+ Response.error =
+ cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLength->second));
+ return false;
+ }
+ if (ExpectedContentSize.value() != ResponseBuffer.GetSize())
+ {
+ Response.error = cpr::Error(
+ /*CURLE_READ_ERROR*/ 26,
+ fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLength->second));
+ return false;
+ }
+ }
+
+ if (Response.status_code == (long)HttpResponseCode::PartialContent)
+ {
+ return true;
+ }
+
+ if (auto JupiterHash = Response.header.find("X-Jupiter-IoHash"); JupiterHash != Response.header.end())
+ {
+ IoHash ExpectedPayloadHash;
+ if (IoHash::TryParse(JupiterHash->second, ExpectedPayloadHash))
+ {
+ IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer);
+ if (PayloadHash != ExpectedPayloadHash)
+ {
+ Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26,
+ fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}",
+ PayloadHash.ToHexString(),
+ ExpectedPayloadHash.ToHexString()));
+ return false;
+ }
+ }
+ }
+
+ if (auto ContentType = Response.header.find("Content-Type"); ContentType != Response.header.end())
+ {
+ if (ContentType->second == "application/x-ue-comp")
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, RawHash, RawSize))
+ {
+ return true;
+ }
+ else
+ {
+ Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, "Compressed binary failed validation");
+ return false;
+ }
+ }
+ if (ContentType->second == "application/x-ue-cb")
+ {
+ if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default);
+ Error == CbValidateError::None)
+ {
+ return true;
+ }
+ else
+ {
+ Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Compact binary failed validation: {}", ToString(Error)));
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+static cpr::Response
+DoWithRetry(
+ std::string_view SessionId,
+ std::function<cpr::Response()>&& Func,
+ uint8_t RetryCount,
+ std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; })
+{
+ uint8_t Attempt = 0;
+ cpr::Response Result = Func();
+ while (Attempt < RetryCount)
+ {
+ if (!ShouldRetry(Result))
+ {
+ if (Result.error || !IsHttpSuccessCode(Result.status_code))
+ {
+ break;
+ }
+ if (Validate(Result))
+ {
+ break;
+ }
+ }
+ Sleep(100 * (Attempt + 1));
+ Attempt++;
+ ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1);
+ Result = Func();
+ }
+ return Result;
+}
+
+static cpr::Response
+DoWithRetry(std::string_view SessionId,
+ std::function<cpr::Response()>&& Func,
+ std::unique_ptr<detail::TempPayloadFile>& PayloadFile,
+ uint8_t RetryCount)
+{
+ uint8_t Attempt = 0;
+ cpr::Response Result = Func();
+ while (Attempt < RetryCount)
+ {
+ if (!ShouldRetry(Result))
+ {
+ if (Result.error || !IsHttpSuccessCode(Result.status_code))
+ {
+ break;
+ }
+ if (ValidatePayload(Result, PayloadFile))
+ {
+ break;
+ }
+ }
+ Sleep(100 * (Attempt + 1));
+ Attempt++;
+ ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1);
+ Result = Func();
+ }
+ return Result;
+}
+
+static std::pair<std::string, std::string>
+HeaderContentType(ZenContentType ContentType)
+{
+ return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)));
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CprHttpClient::CprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings)
+: HttpClientBase(BaseUri, Connectionsettings)
+{
+}
+
+CprHttpClient::~CprHttpClient()
+{
+ ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient");
+ m_SessionLock.WithExclusiveLock([&] {
+ for (auto CprSession : m_Sessions)
+ {
+ delete CprSession;
+ }
+ m_Sessions.clear();
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CprHttpClient::Session
+CprHttpClient::AllocSession(const std::string_view BaseUrl,
+ const std::string_view ResourcePath,
+ const HttpClientSettings& ConnectionSettings,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters,
+ const std::string_view SessionId,
+ std::optional<HttpClientAccessToken> AccessToken)
+{
+ ZEN_TRACE_CPU("CprHttpClient::AllocSession");
+ cpr::Session* CprSession = nullptr;
+ m_SessionLock.WithExclusiveLock([&] {
+ if (!m_Sessions.empty())
+ {
+ CprSession = m_Sessions.back();
+ m_Sessions.pop_back();
+ }
+ });
+
+ if (CprSession == nullptr)
+ {
+ CprSession = new cpr::Session();
+ CprSession->SetConnectTimeout(ConnectionSettings.ConnectTimeout);
+ CprSession->SetTimeout(ConnectionSettings.Timeout);
+ if (ConnectionSettings.AssumeHttp2)
+ {
+ CprSession->SetHttpVersion(cpr::HttpVersion{cpr::HttpVersionCode::VERSION_2_0_PRIOR_KNOWLEDGE});
+ }
+ }
+
+ if (!AdditionalHeader->empty())
+ {
+ CprSession->SetHeader(cpr::Header(AdditionalHeader->begin(), AdditionalHeader->end()));
+ }
+ if (!SessionId.empty())
+ {
+ CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}});
+ }
+ if (AccessToken)
+ {
+ CprSession->UpdateHeader({{"Authorization", AccessToken->Value}});
+ }
+ if (!Parameters->empty())
+ {
+ cpr::Parameters Tmp;
+ for (auto It = Parameters->begin(); It != Parameters->end(); It++)
+ {
+ Tmp.Add({It->first, It->second});
+ }
+ CprSession->SetParameters(Tmp);
+ }
+ else
+ {
+ CprSession->SetParameters({});
+ }
+
+ ExtendableStringBuilder<128> UrlBuffer;
+ UrlBuffer << BaseUrl << ResourcePath;
+ CprSession->SetUrl(UrlBuffer.c_str());
+
+ return Session(this, CprSession);
+}
+
+void
+CprHttpClient::ReleaseSession(cpr::Session* CprSession)
+{
+ ZEN_TRACE_CPU("CprHttpClient::ReleaseSession");
+ CprSession->SetUrl({});
+ CprSession->SetHeader({});
+ CprSession->SetBody({});
+ m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(CprSession); });
+}
+
+CprHttpClient::Response
+CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::TransactPackage");
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+
+ // First, list of offered chunks for filtering on the server end
+
+ std::vector<IoHash> AttachmentsToSend;
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+
+ const uint32_t RequestId = ++HttpClientRequestIdCounter;
+ auto RequestIdString = fmt::to_string(RequestId);
+
+ if (Attachments.empty() == false)
+ {
+ CbObjectWriter Writer;
+ Writer.BeginArray("offer");
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ Writer.AddHash(Attachment.GetHash());
+ }
+
+ Writer.EndArray();
+
+ BinaryWriter MemWriter;
+ Writer.Save(MemWriter);
+
+ Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackageOffer), {"UE-Request", RequestIdString}});
+ Sess->SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()});
+
+ cpr::Response FilterResponse = Sess.Post();
+
+ if (FilterResponse.status_code == 200)
+ {
+ IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size());
+ CbValidateError ValidationError = CbValidateError::None;
+ if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError);
+ ValidationError == CbValidateError::None)
+ {
+ for (CbFieldView& Entry : ResponseObject["need"])
+ {
+ ZEN_ASSERT(Entry.IsHash());
+ AttachmentsToSend.push_back(Entry.AsHash());
+ }
+ }
+ }
+ }
+
+ // Prepare package for send
+
+ CbPackage SendPackage;
+ SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash());
+
+ for (const IoHash& AttachmentCid : AttachmentsToSend)
+ {
+ const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid);
+
+ if (Attachment)
+ {
+ SendPackage.AddAttachment(*Attachment);
+ }
+ else
+ {
+ // This should be an error -- server asked to have something we can't find
+ }
+ }
+
+ // Transmit package payload
+
+ CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage);
+ SharedBuffer FlatMessage = Message.Flatten();
+
+ Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackage), {"UE-Request", RequestIdString}});
+ Sess->SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()});
+
+ cpr::Response FilterResponse = Sess.Post();
+
+ if (!IsHttpSuccessCode(FilterResponse.status_code))
+ {
+ return {.StatusCode = HttpResponseCode(FilterResponse.status_code)};
+ }
+
+ IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size());
+
+ if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end())
+ {
+ HttpContentType ContentType = ParseContentType(It->second);
+
+ ResponseBuffer.SetContentType(ContentType);
+ }
+
+ return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer};
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Standard HTTP verbs
+//
+
+CprHttpClient::Response
+CprHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Put");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Sess->SetBody(AsCprBody(Payload));
+ Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())});
+ return Sess.Put();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Put");
+
+ return CommonResponse(m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri,
+ Url,
+ m_ConnectionSettings,
+ {{"Content-Length", "0"}},
+ Parameters,
+ m_SessionId,
+ GetAccessToken());
+ return Sess.Put();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Get");
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken());
+ return Sess.Get();
+ },
+ m_ConnectionSettings.RetryCount,
+ [](cpr::Response& Result) {
+ std::unique_ptr<detail::TempPayloadFile> NoTempFile;
+ return ValidatePayload(Result, NoTempFile);
+ }));
+}
+
+CprHttpClient::Response
+CprHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Head");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ return Sess.Head();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Delete");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ return Sess.Delete();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CprHttpClient::PostNoPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken());
+ return Sess.Post();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader);
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::PostWithPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Sess->UpdateHeader({HeaderContentType(ContentType)});
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ uint64_t Offset = 0;
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+ auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, Payload.GetSize() - Offset);
+ Buffer.Read(buffer, size);
+ Offset += size;
+ return true;
+ };
+ return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
+ }
+ Sess->SetBody(AsCprBody(Payload));
+ return Sess.Post();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+
+ Sess->SetBody(AsCprBody(Payload));
+ Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)});
+ return Sess.Post();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader);
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Post");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Sess->UpdateHeader({HeaderContentType(ContentType)});
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+ auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) {
+ size = Reader.Read(buffer, size);
+ return true;
+ };
+ return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())});
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ uint64_t Offset = 0;
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+ auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, Payload.GetSize() - Offset);
+ Buffer.Read(buffer, size);
+ Offset += size;
+ return true;
+ };
+ return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
+ }
+ Sess->SetBody(AsCprBody(Payload));
+ return Sess.Put();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Sess->UpdateHeader({HeaderContentType(ContentType)});
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+ auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) {
+ size = Reader.Read(buffer, size);
+ return true;
+ };
+ return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Download");
+
+ std::string PayloadString;
+ std::unique_ptr<detail::TempPayloadFile> PayloadFile;
+ cpr::Response Response = DoWithRetry(
+ m_SessionId,
+ [&]() {
+ auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> {
+ size_t DelimiterPos = header.find(':');
+ if (DelimiterPos != std::string::npos)
+ {
+ std::string Key = header.substr(0, DelimiterPos);
+ constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n");
+ Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters);
+ Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters);
+
+ std::string Value = header.substr(DelimiterPos + 1);
+ Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters);
+ Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters);
+
+ return std::make_pair(Key, Value);
+ }
+ return std::make_pair(header, "");
+ };
+
+ auto DownloadCallback = [&](std::string data, intptr_t) {
+ if (PayloadFile)
+ {
+ ZEN_ASSERT(PayloadString.empty());
+ std::error_code Ec = PayloadFile->Write(data);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ return false;
+ }
+ }
+ else
+ {
+ PayloadString.append(data);
+ }
+ return true;
+ };
+
+ uint64_t RequestedContentLength = (uint64_t)-1;
+ if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
+ {
+ if (RangeIt->second.starts_with("bytes"))
+ {
+ size_t RangeStartPos = RangeIt->second.find('=', 5);
+ if (RangeStartPos != std::string::npos)
+ {
+ RangeStartPos++;
+ size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos);
+ if (RangeSplitPos != std::string::npos)
+ {
+ std::optional<size_t> RequestedRangeStart =
+ ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ RequestedContentLength = RequestedRangeEnd.value() - 1;
+ }
+ }
+ }
+ }
+ }
+
+ cpr::Response Response;
+ {
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ std::pair<std::string, std::string> Header = GetHeader(header);
+ if (Header.first == "Content-Length"sv)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
+ if (ContentLength.has_value())
+ {
+ if (ContentLength.value() > 1024 * 1024)
+ {
+ PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ PayloadFile.reset();
+ }
+ }
+ else
+ {
+ PayloadString.reserve(ContentLength.value());
+ }
+ }
+ }
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+ return 1;
+ };
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ }
+ if (m_ConnectionSettings.AllowResume)
+ {
+ auto SupportsRanges = [](const cpr::Response& Response) -> bool {
+ if (Response.header.find("Content-Range") != Response.header.end())
+ {
+ return true;
+ }
+ if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end())
+ {
+ return It->second == "bytes"sv;
+ }
+ return false;
+ };
+
+ auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool {
+ if (ShouldRetry(Response))
+ {
+ return SupportsRanges(Response);
+ }
+ return false;
+ };
+
+ if (ShouldResume(Response))
+ {
+ auto It = Response.header.find("Content-Length");
+ if (It != Response.header.end())
+ {
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ std::pair<std::string, std::string> Header = GetHeader(header);
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+
+ if (Header.first == "Content-Range"sv)
+ {
+ if (Header.second.starts_with("bytes "sv))
+ {
+ size_t RangeStartEnd = Header.second.find('-', 6);
+ if (RangeStartEnd != std::string::npos)
+ {
+ const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6));
+ if (Start)
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+ if (Start.value() == DownloadedSize)
+ {
+ return 1;
+ }
+ else if (Start.value() > DownloadedSize)
+ {
+ return 0;
+ }
+ if (PayloadFile)
+ {
+ PayloadFile->ResetWritePos(Start.value());
+ }
+ else
+ {
+ PayloadString = PayloadString.substr(0, Start.value());
+ }
+ return 1;
+ }
+ }
+ }
+ return 0;
+ }
+ return 1;
+ };
+
+ KeyValueMap HeadersWithRange(AdditionalHeader);
+ do
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+
+ uint64_t ContentLength = RequestedContentLength;
+ if (ContentLength == uint64_t(-1))
+ {
+ if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value())
+ {
+ ContentLength = ParsedContentLength.value();
+ }
+ }
+
+ std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
+ if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
+ {
+ if (RangeIt->second == Range)
+ {
+ // If we didn't make any progress, abort
+ break;
+ }
+ }
+ HeadersWithRange.Entries.insert_or_assign("Range", Range);
+
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ ReceivedHeaders.clear();
+ } while (ShouldResume(Response));
+ }
+ }
+ }
+
+ if (!PayloadString.empty())
+ {
+ Response.text = std::move(PayloadString);
+ }
+ return Response;
+ },
+ PayloadFile,
+ m_ConnectionSettings.RetryCount);
+
+ return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{});
+}
+
+} // namespace zen \ No newline at end of file