diff options
| author | Stefan Boberg <[email protected]> | 2025-09-30 19:07:51 +0200 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2025-09-30 19:07:51 +0200 |
| commit | 634181a04efff90def7a98d98eac7078e1d4e62d (patch) | |
| tree | 04678bba636a76d21f300ff6e73af4473274cf12 /src/zenhttp/clients/httpclientcpr.cpp | |
| parent | use batching clang-format for quicker turnaround on validate actions (#529) (diff) | |
| download | zen-634181a04efff90def7a98d98eac7078e1d4e62d.tar.xz zen-634181a04efff90def7a98d98eac7078e1d4e62d.zip | |
HttpClient support for pluggable back-ends (#532)
refactored HttpClient to separate out cpr implementation into separate classes, with an abstract base class to allow plugging in multiple implementations in the future
Diffstat (limited to 'src/zenhttp/clients/httpclientcpr.cpp')
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.cpp | 1035 |
1 files changed, 1035 insertions, 0 deletions
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp new file mode 100644 index 000000000..568106887 --- /dev/null +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -0,0 +1,1035 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcpr.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/compress.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zenhttp/packageformat.h> + +namespace zen { + +HttpClientBase* +CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings) +{ + return new CprHttpClient(BaseUri, ConnectionSettings); +} + +static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; + +// If we want to support different HTTP client implementations then we'll need to make this more abstract + +HttpClientError::ResponseClass +HttpClientError::GetResponseClass() const +{ + if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK) + { + switch ((cpr::ErrorCode)m_Error) + { + case cpr::ErrorCode::CONNECTION_FAILURE: + return ResponseClass::kHttpCantConnectError; + case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: + case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: + return ResponseClass::kHttpNoHost; + case cpr::ErrorCode::INTERNAL_ERROR: + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return ResponseClass::kHttpTimeout; + case cpr::ErrorCode::SSL_CONNECT_ERROR: + case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_CACERT_ERROR: + case cpr::ErrorCode::GENERIC_SSL_ERROR: + return ResponseClass::kHttpSLLError; + default: + return ResponseClass::kHttpOtherClientError; + } + } + else if (IsHttpSuccessCode(m_ResponseCode)) + { + return ResponseClass::kSuccess; + } + else + { + switch (m_ResponseCode) + { + case HttpResponseCode::Unauthorized: + return ResponseClass::kHttpUnauthorized; + case HttpResponseCode::NotFound: + return ResponseClass::kHttpNotFound; + case HttpResponseCode::Forbidden: + return ResponseClass::kHttpForbidden; + case HttpResponseCode::Conflict: + return ResponseClass::kHttpConflict; + case HttpResponseCode::InternalServerError: + return ResponseClass::kHttpInternalServerError; + case HttpResponseCode::ServiceUnavailable: + return ResponseClass::kHttpServiceUnavailable; + case HttpResponseCode::BadGateway: + return ResponseClass::kHttpBadGateway; + case HttpResponseCode::GatewayTimeout: + return ResponseClass::kHttpGatewayTimeout; + default: + if (m_ResponseCode >= HttpResponseCode::InternalServerError) + { + return ResponseClass::kHttpOtherServerError; + } + else + { + return ResponseClass::kHttpOtherClientError; + } + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// CPR helpers + +static cpr::Body +AsCprBody(const CbObject& Obj) +{ + return cpr::Body((const char*)Obj.GetBuffer().GetData(), Obj.GetBuffer().GetSize()); +} + +static cpr::Body +AsCprBody(const IoBuffer& Obj) +{ + return cpr::Body((const char*)Obj.GetData(), Obj.GetSize()); +} + +////////////////////////////////////////////////////////////////////////// + +static HttpClient::Response +ResponseWithPayload(std::string_view SessionId, cpr::Response& HttpResponse, const HttpResponseCode WorkResponseCode, IoBuffer&& Payload) +{ + // This ends up doing a memcpy, would be good to get rid of it by streaming results + // into buffer directly + IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, HttpResponse.text.data(), HttpResponse.text.size()); + + if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end()) + { + const HttpContentType ContentType = ParseContentType(It->second); + + ResponseBuffer.SetContentType(ContentType); + } + + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + { + ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse); + } + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .ResponsePayload = std::move(ResponseBuffer), + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed}; +} + +static HttpClient::Response +CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload = {}) +{ + const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code); + if (HttpResponse.error) + { + if (HttpResponse.error.code != cpr::ErrorCode::OPERATION_TIMEDOUT && + HttpResponse.error.code != cpr::ErrorCode::CONNECTION_FAILURE && HttpResponse.error.code != cpr::ErrorCode::REQUEST_CANCELLED) + { + ZEN_WARN("HttpClient client failure (session: {}): {}", SessionId, HttpResponse); + } + + // Client side failure code + return HttpClient::Response{ + .StatusCode = WorkResponseCode, + .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()), + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed, + .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code), + .ErrorMessage = HttpResponse.error.message}}; + } + + if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload)) + { + return HttpClient::Response{.StatusCode = WorkResponseCode, + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed}; + } + else + { + return ResponseWithPayload( + SessionId, + HttpResponse, + WorkResponseCode, + Payload ? std::move(Payload) : IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size())); + } +} + +static bool +ShouldRetry(const cpr::Response& Response) +{ + switch (Response.error.code) + { + case cpr::ErrorCode::OK: + break; + case cpr::ErrorCode::INTERNAL_ERROR: + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return true; + default: + return false; + } + switch ((HttpResponseCode)Response.status_code) + { + case HttpResponseCode::RequestTimeout: + case HttpResponseCode::TooManyRequests: + case HttpResponseCode::InternalServerError: + case HttpResponseCode::BadGateway: + case HttpResponseCode::ServiceUnavailable: + case HttpResponseCode::GatewayTimeout: + return true; + default: + return false; + } +}; + +static bool +ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + ZEN_TRACE_CPU("ValidatePayload"); + IoBuffer ResponseBuffer = (Response.text.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() + : IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size()); + + if (auto ContentLength = Response.header.find("Content-Length"); ContentLength != Response.header.end()) + { + std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLength->second); + if (!ExpectedContentSize.has_value()) + { + Response.error = + cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLength->second)); + return false; + } + if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) + { + Response.error = cpr::Error( + /*CURLE_READ_ERROR*/ 26, + fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLength->second)); + return false; + } + } + + if (Response.status_code == (long)HttpResponseCode::PartialContent) + { + return true; + } + + if (auto JupiterHash = Response.header.find("X-Jupiter-IoHash"); JupiterHash != Response.header.end()) + { + IoHash ExpectedPayloadHash; + if (IoHash::TryParse(JupiterHash->second, ExpectedPayloadHash)) + { + IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); + if (PayloadHash != ExpectedPayloadHash) + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, + fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", + PayloadHash.ToHexString(), + ExpectedPayloadHash.ToHexString())); + return false; + } + } + } + + if (auto ContentType = Response.header.find("Content-Type"); ContentType != Response.header.end()) + { + if (ContentType->second == "application/x-ue-comp") + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, RawHash, RawSize)) + { + return true; + } + else + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, "Compressed binary failed validation"); + return false; + } + } + if (ContentType->second == "application/x-ue-cb") + { + if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); + Error == CbValidateError::None) + { + return true; + } + else + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Compact binary failed validation: {}", ToString(Error))); + return false; + } + } + } + + return true; +} + +static cpr::Response +DoWithRetry( + std::string_view SessionId, + std::function<cpr::Response()>&& Func, + uint8_t RetryCount, + std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; }) +{ + uint8_t Attempt = 0; + cpr::Response Result = Func(); + while (Attempt < RetryCount) + { + if (!ShouldRetry(Result)) + { + if (Result.error || !IsHttpSuccessCode(Result.status_code)) + { + break; + } + if (Validate(Result)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1); + Result = Func(); + } + return Result; +} + +static cpr::Response +DoWithRetry(std::string_view SessionId, + std::function<cpr::Response()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile, + uint8_t RetryCount) +{ + uint8_t Attempt = 0; + cpr::Response Result = Func(); + while (Attempt < RetryCount) + { + if (!ShouldRetry(Result)) + { + if (Result.error || !IsHttpSuccessCode(Result.status_code)) + { + break; + } + if (ValidatePayload(Result, PayloadFile)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1); + Result = Func(); + } + return Result; +} + +static std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +////////////////////////////////////////////////////////////////////////// + +CprHttpClient::CprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings) +: HttpClientBase(BaseUri, Connectionsettings) +{ +} + +CprHttpClient::~CprHttpClient() +{ + ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient"); + m_SessionLock.WithExclusiveLock([&] { + for (auto CprSession : m_Sessions) + { + delete CprSession; + } + m_Sessions.clear(); + }); +} + +////////////////////////////////////////////////////////////////////////// + +CprHttpClient::Session +CprHttpClient::AllocSession(const std::string_view BaseUrl, + const std::string_view ResourcePath, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken) +{ + ZEN_TRACE_CPU("CprHttpClient::AllocSession"); + cpr::Session* CprSession = nullptr; + m_SessionLock.WithExclusiveLock([&] { + if (!m_Sessions.empty()) + { + CprSession = m_Sessions.back(); + m_Sessions.pop_back(); + } + }); + + if (CprSession == nullptr) + { + CprSession = new cpr::Session(); + CprSession->SetConnectTimeout(ConnectionSettings.ConnectTimeout); + CprSession->SetTimeout(ConnectionSettings.Timeout); + if (ConnectionSettings.AssumeHttp2) + { + CprSession->SetHttpVersion(cpr::HttpVersion{cpr::HttpVersionCode::VERSION_2_0_PRIOR_KNOWLEDGE}); + } + } + + if (!AdditionalHeader->empty()) + { + CprSession->SetHeader(cpr::Header(AdditionalHeader->begin(), AdditionalHeader->end())); + } + if (!SessionId.empty()) + { + CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}}); + } + if (AccessToken) + { + CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); + } + if (!Parameters->empty()) + { + cpr::Parameters Tmp; + for (auto It = Parameters->begin(); It != Parameters->end(); It++) + { + Tmp.Add({It->first, It->second}); + } + CprSession->SetParameters(Tmp); + } + else + { + CprSession->SetParameters({}); + } + + ExtendableStringBuilder<128> UrlBuffer; + UrlBuffer << BaseUrl << ResourcePath; + CprSession->SetUrl(UrlBuffer.c_str()); + + return Session(this, CprSession); +} + +void +CprHttpClient::ReleaseSession(cpr::Session* CprSession) +{ + ZEN_TRACE_CPU("CprHttpClient::ReleaseSession"); + CprSession->SetUrl({}); + CprSession->SetHeader({}); + CprSession->SetBody({}); + m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(CprSession); }); +} + +CprHttpClient::Response +CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::TransactPackage"); + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++HttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (Attachments.empty() == false) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + Writer.AddHash(Attachment.GetHash()); + } + + Writer.EndArray(); + + BinaryWriter MemWriter; + Writer.Save(MemWriter); + + Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackageOffer), {"UE-Request", RequestIdString}}); + Sess->SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (FilterResponse.status_code == 200) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size()); + CbValidateError ValidationError = CbValidateError::None; + if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); + ValidationError == CbValidateError::None) + { + for (CbFieldView& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + else + { + // This should be an error -- server asked to have something we can't find + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackage), {"UE-Request", RequestIdString}}); + Sess->SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (!IsHttpSuccessCode(FilterResponse.status_code)) + { + return {.StatusCode = HttpResponseCode(FilterResponse.status_code)}; + } + + IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size()); + + if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end()) + { + HttpContentType ContentType = ParseContentType(It->second); + + ResponseBuffer.SetContentType(ContentType); + } + + return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; +} + +////////////////////////////////////////////////////////////////////////// +// +// Standard HTTP verbs +// + +CprHttpClient::Response +CprHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->SetBody(AsCprBody(Payload)); + Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())}); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::Put"); + + return CommonResponse(m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, + Url, + m_ConnectionSettings, + {{"Content-Length", "0"}}, + Parameters, + m_SessionId, + GetAccessToken()); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::Get"); + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + return Sess.Get(); + }, + m_ConnectionSettings.RetryCount, + [](cpr::Response& Result) { + std::unique_ptr<detail::TempPayloadFile> NoTempFile; + return ValidatePayload(Result, NoTempFile); + })); +} + +CprHttpClient::Response +CprHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Head"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + return Sess.Head(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Delete"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + return Sess.Delete(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::PostNoPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::PostWithPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + uint64_t Offset = 0; + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, Payload.GetSize() - Offset); + Buffer.Read(buffer, size); + Offset += size; + return true; + }; + return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + } + Sess->SetBody(AsCprBody(Payload)); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + + Sess->SetBody(AsCprBody(Payload)); + Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Post"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) { + size = Reader.Read(buffer, size); + return true; + }; + return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())}); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + uint64_t Offset = 0; + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, Payload.GetSize() - Offset); + Buffer.Read(buffer, size); + Offset += size; + return true; + }; + return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + } + Sess->SetBody(AsCprBody(Payload)); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) { + size = Reader.Read(buffer, size); + return true; + }; + return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Download"); + + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + cpr::Response Response = DoWithRetry( + m_SessionId, + [&]() { + auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> { + size_t DelimiterPos = header.find(':'); + if (DelimiterPos != std::string::npos) + { + std::string Key = header.substr(0, DelimiterPos); + constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); + Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); + Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); + + std::string Value = header.substr(DelimiterPos + 1); + Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); + Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); + + return std::make_pair(Key, Value); + } + return std::make_pair(header, ""); + }; + + auto DownloadCallback = [&](std::string data, intptr_t) { + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + size_t RangeStartPos = RangeIt->second.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos); + if (RangeSplitPos != std::string::npos) + { + std::optional<size_t> RequestedRangeStart = + ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength = RequestedRangeEnd.value() - 1; + } + } + } + } + } + + cpr::Response Response; + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + std::pair<std::string, std::string> Header = GetHeader(header); + if (Header.first == "Content-Length"sv) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); + if (ContentLength.has_value()) + { + if (ContentLength.value() > 1024 * 1024) + { + PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + } + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const cpr::Response& Response) -> bool { + if (Response.header.find("Content-Range") != Response.header.end()) + { + return true; + } + if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) + { + return It->second == "bytes"sv; + } + return false; + }; + + auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool { + if (ShouldRetry(Response)) + { + return SupportsRanges(Response); + } + return false; + }; + + if (ShouldResume(Response)) + { + auto It = Response.header.find("Content-Length"); + if (It != Response.header.end()) + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + + auto HeaderCallback = [&](std::string header, intptr_t) { + std::pair<std::string, std::string> Header = GetHeader(header); + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + + if (Header.first == "Content-Range"sv) + { + if (Header.second.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Header.second.find('-', 6); + if (RangeStartEnd != std::string::npos) + { + const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + if (Start.value() == DownloadedSize) + { + return 1; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (PayloadFile) + { + PayloadFile->ResetWritePos(Start.value()); + } + else + { + PayloadString = PayloadString.substr(0, Start.value()); + } + return 1; + } + } + } + return 0; + } + return 1; + }; + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + // If we didn't make any progress, abort + break; + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + ReceivedHeaders.clear(); + } while (ShouldResume(Response)); + } + } + } + + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile, + m_ConnectionSettings.RetryCount); + + return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); +} + +} // namespace zen
\ No newline at end of file |