// Copyright Epic Games, Inc. All Rights Reserved.

#include "httpclientcpr.h"

// NOTE(review): the targets of the following #include directives are missing in this copy of the
// file (everything between angle brackets appears to have been stripped, including template
// arguments further down); restore them from version control.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace zen {

// Factory for the cpr-backed HTTP client implementation.
// The returned object is allocated with `new`; releasing it is up to the caller.
HttpClientBase*
CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function&& CheckIfAbortFunction)
{
	return new CprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
}

// Process-wide counter; TransactPackage stamps each package transaction with the next value in the
// "UE-Request" header.
static std::atomic HttpClientRequestIdCounter{0};

// Returns true when ErrorCode, interpreted as a cpr::ErrorCode, denotes a failure to reach the
// server at all: connection failure, timeout, or host/proxy name resolution failure.
bool
HttpClient::ErrorContext::IsConnectionError() const
{
	switch (static_cast(ErrorCode))
	{
		case cpr::ErrorCode::CONNECTION_FAILURE:
		case cpr::ErrorCode::OPERATION_TIMEDOUT:
		case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
		case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
			return true;
		default:
			return false;
	}
}

// Classifies this error into a ResponseClass.
// A transport error (m_Error, interpreted as a cpr::ErrorCode) takes precedence. Only when the
// transport succeeded is the HTTP status code (m_ResponseCode) inspected.
// Note that generic internal/network send/receive errors are reported as kHttpTimeout.
// If we want to support different HTTP client implementations then we'll need to make this more abstract
HttpClientError::ResponseClass
HttpClientError::GetResponseClass() const
{
	if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK)
	{
		switch ((cpr::ErrorCode)m_Error)
		{
			case cpr::ErrorCode::CONNECTION_FAILURE:
				return ResponseClass::kHttpCantConnectError;
			case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
			case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
				return ResponseClass::kHttpNoHost;
			case cpr::ErrorCode::INTERNAL_ERROR:
			case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
			case cpr::ErrorCode::NETWORK_SEND_FAILURE:
			case cpr::ErrorCode::OPERATION_TIMEDOUT:
				return ResponseClass::kHttpTimeout;
			case cpr::ErrorCode::SSL_CONNECT_ERROR:
			case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR:
			case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR:
			case cpr::ErrorCode::SSL_CACERT_ERROR:
			case cpr::ErrorCode::GENERIC_SSL_ERROR:
				return ResponseClass::kHttpSLLError;
			default:
				return ResponseClass::kHttpOtherClientError;
		}
	}
	else if (IsHttpSuccessCode(m_ResponseCode))
	{
		return ResponseClass::kSuccess;
	}
	else
	{
		switch (m_ResponseCode)
		{
			case HttpResponseCode::Unauthorized:
				return ResponseClass::kHttpUnauthorized;
			case HttpResponseCode::NotFound:
				return ResponseClass::kHttpNotFound;
			case HttpResponseCode::Forbidden:
				return ResponseClass::kHttpForbidden;
			case HttpResponseCode::Conflict:
				return ResponseClass::kHttpConflict;
			case HttpResponseCode::InternalServerError:
				return ResponseClass::kHttpInternalServerError;
			case HttpResponseCode::ServiceUnavailable:
				return ResponseClass::kHttpServiceUnavailable;
			case HttpResponseCode::BadGateway:
				return ResponseClass::kHttpBadGateway;
			case HttpResponseCode::GatewayTimeout:
				return ResponseClass::kHttpGatewayTimeout;
			default:
				// Any remaining status at or above InternalServerError is a server-side failure;
				// everything else is attributed to the client.
				if (m_ResponseCode >= HttpResponseCode::InternalServerError)
				{
					return ResponseClass::kHttpOtherServerError;
				}
				else
				{
					return ResponseClass::kHttpOtherClientError;
				}
		}
	}
}

//////////////////////////////////////////////////////////////////////////
//
// CPR helpers

// Builds a cpr request body from the serialized bytes of a compact binary object.
static cpr::Body
AsCprBody(const CbObject& Obj)
{
	return cpr::Body((const char*)Obj.GetBuffer().GetData(), Obj.GetBuffer().GetSize());
}

// Builds a cpr request body from the bytes of an IoBuffer.
static cpr::Body
AsCprBody(const IoBuffer& Obj)
{
	return cpr::Body((const char*)Obj.GetData(), Obj.GetSize());
}

// Decides whether a completed request is worth retrying.
// Transport errors: only internal errors, network send/receive failures and timeouts are retried;
// any other transport error is treated as permanent.
// Transport OK: retry on RequestTimeout, TooManyRequests, InternalServerError, BadGateway,
// ServiceUnavailable and GatewayTimeout status codes.
static bool
ShouldRetry(const cpr::Response& Response)
{
	switch (Response.error.code)
	{
		case cpr::ErrorCode::OK:
			break;
		case cpr::ErrorCode::INTERNAL_ERROR:
		case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
		case cpr::ErrorCode::NETWORK_SEND_FAILURE:
		case cpr::ErrorCode::OPERATION_TIMEDOUT:
			return true;
		default:
			return false;
	}
	switch ((HttpResponseCode)Response.status_code)
	{
		case HttpResponseCode::RequestTimeout:
		case HttpResponseCode::TooManyRequests:
		case HttpResponseCode::InternalServerError:
		case HttpResponseCode::BadGateway:
		case HttpResponseCode::ServiceUnavailable:
		case HttpResponseCode::GatewayTimeout:
			return true;
		default:
			return false;
	}
};

// Returns the ("Content-Type", <mime string>) header pair for a Zen content type.
static std::pair
HeaderContentType(ZenContentType ContentType)
{
	return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)));
}

//////////////////////////////////////////////////////////////////////////

// All state is initialised by HttpClientBase; sessions are created lazily in AllocSession.
CprHttpClient::CprHttpClient(std::string_view BaseUri,
							 const HttpClientSettings& Connectionsettings,
							 std::function&& CheckIfAbortFunction)
: HttpClientBase(BaseUri, Connectionsettings, std::move(CheckIfAbortFunction))
{
}

// Deletes every cpr session currently sitting in the pool, under the pool lock.
// Sessions handed out by AllocSession are not in m_Sessions while they are checked out.
CprHttpClient::~CprHttpClient()
{
	ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient");
	m_SessionLock.WithExclusiveLock([&] {
		for (auto CprSession : m_Sessions)
		{
			delete CprSession;
		}
		m_Sessions.clear();
	});
}

// Builds the Response for a request whose transport succeeded and which carries a body.
//  - Payload (e.g. a downloaded temp file) is used if non-empty; otherwise HttpResponse.text is cloned.
//  - The payload is tagged with the response's parsed Content-Type, if present.
//  - A warning is logged for non-success status codes other than NotFound, unless the abort
//    callback reports true (in which case logging is suppressed).
//  - Multipart boundaries are sorted by RangeOffset before being returned in Ranges.
HttpClient::Response
CprHttpClient::ResponseWithPayload(std::string_view SessionId,
								   cpr::Response&& HttpResponse,
								   const HttpResponseCode WorkResponseCode,
								   IoBuffer&& Payload,
								   std::vector&& BoundaryPositions)
{
	// This ends up doing a memcpy, would be good to get rid of it by streaming results
	// into buffer directly
	IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, HttpResponse.text.data(), HttpResponse.text.size());

	if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end())
	{
		const HttpContentType ContentType = ParseContentType(It->second);
		ResponseBuffer.SetContentType(ContentType);
	}

	// Once the caller has asked to abort, failures are expected and not worth logging.
	const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();

	if (!Quiet)
	{
		if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
		{
			ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse);
		}
	}

	std::sort(BoundaryPositions.begin(),
			  BoundaryPositions.end(),
			  [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) {
				  return Lhs.RangeOffset < Rhs.RangeOffset;
			  });

	return HttpClient::Response{.StatusCode = WorkResponseCode,
								.ResponsePayload = std::move(ResponseBuffer),
								.Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
								.UploadedBytes = gsl::narrow(HttpResponse.uploaded_bytes),
								.DownloadedBytes = gsl::narrow(HttpResponse.downloaded_bytes),
								.ElapsedSeconds = HttpResponse.elapsed,
								.Ranges = std::move(BoundaryPositions)};
}

// Converts a raw cpr::Response into an HttpClient::Response.
//  - Transport error: the response carries an ErrorContext (cpr error code + message) and a clone of
//    the response text; Payload and BoundaryPositions are not used. A warning is logged unless the
//    error is a timeout, connection failure or cancellation, or the abort callback reports true.
//  - NoContent, or no body at all: a header-only response.
//  - Otherwise: delegated to ResponseWithPayload.
HttpClient::Response
CprHttpClient::CommonResponse(std::string_view SessionId,
							  cpr::Response&& HttpResponse,
							  IoBuffer&& Payload,
							  std::vector&& BoundaryPositions)
{
	const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code);
	if (HttpResponse.error)
	{
		const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
		if (!Quiet)
		{
			if (HttpResponse.error.code != cpr::ErrorCode::OPERATION_TIMEDOUT &&
				HttpResponse.error.code != cpr::ErrorCode::CONNECTION_FAILURE &&
				HttpResponse.error.code != cpr::ErrorCode::REQUEST_CANCELLED)
			{
				ZEN_WARN("HttpClient client failure (session: {}): {}", SessionId, HttpResponse);
			}
		}

		// Client side failure code
		return HttpClient::Response{
			.StatusCode = WorkResponseCode,
			.ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()),
			.Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
			.UploadedBytes = gsl::narrow(HttpResponse.uploaded_bytes),
			.DownloadedBytes = gsl::narrow(HttpResponse.downloaded_bytes),
			.ElapsedSeconds = HttpResponse.elapsed,
			.Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow(HttpResponse.error.code),
											  .ErrorMessage = HttpResponse.error.message}};
	}

	if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload))
	{
		return HttpClient::Response{.StatusCode = WorkResponseCode,
									.Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
									.UploadedBytes = gsl::narrow(HttpResponse.uploaded_bytes),
									.DownloadedBytes = gsl::narrow(HttpResponse.downloaded_bytes),
									.ElapsedSeconds = HttpResponse.elapsed};
	}
	else
	{
		return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions));
	}
}

// Integrity checks on a received body. The body is Response.text, or the temp file's contents when
// the text is empty and PayloadFile is set.
// On failure, Response.error is set to code 26 (CURLE_READ_ERROR) with a description and false is
// returned. Checks, in order:
//  1. Body size equals the Content-Length header (if present and parseable).
//  2. PartialContent responses stop here (hash/type checks would not apply to a byte range).
//  3. If X-Jupiter-IoHash parses as an IoHash, the body must hash to it.
//  4. application/x-ue-comp bodies must pass CompressedBuffer header validation.
//  5. application/x-ue-cb bodies must pass compact binary validation.
bool
CprHttpClient::ValidatePayload(cpr::Response& Response, std::unique_ptr& PayloadFile)
{
	ZEN_TRACE_CPU("ValidatePayload");

	IoBuffer ResponseBuffer = (Response.text.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer()
																	 : IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size());

	if (auto ContentLength = Response.header.find("Content-Length"); ContentLength != Response.header.end())
	{
		std::optional ExpectedContentSize = ParseInt(ContentLength->second);
		if (!ExpectedContentSize.has_value())
		{
			Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLength->second));
			return false;
		}
		if (ExpectedContentSize.value() != ResponseBuffer.GetSize())
		{
			Response.error = cpr::Error(
				/*CURLE_READ_ERROR*/ 26,
				fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLength->second));
			return false;
		}
	}

	if (Response.status_code == (long)HttpResponseCode::PartialContent)
	{
		return true;
	}

	if (auto JupiterHash = Response.header.find("X-Jupiter-IoHash"); JupiterHash != Response.header.end())
	{
		IoHash ExpectedPayloadHash;
		// An unparseable hash header is ignored rather than treated as a mismatch.
		if (IoHash::TryParse(JupiterHash->second, ExpectedPayloadHash))
		{
			IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer);
			if (PayloadHash != ExpectedPayloadHash)
			{
				Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26,
											fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}",
														PayloadHash.ToHexString(),
														ExpectedPayloadHash.ToHexString()));
				return false;
			}
		}
	}

	if (auto ContentType = Response.header.find("Content-Type"); ContentType != Response.header.end())
	{
		if (ContentType->second == "application/x-ue-comp")
		{
			IoHash RawHash;
			uint64_t RawSize;
			if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, RawHash, RawSize, /*OutOptionalTotalCompressedSize*/ nullptr))
			{
				return true;
			}
			else
			{
				Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, "Compressed binary failed validation");
				return false;
			}
		}
		if (ContentType->second == "application/x-ue-cb")
		{
			if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default);
				Error == CbValidateError::None)
			{
				return true;
			}
			else
			{
				Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Compact binary failed validation: {}", ToString(Error)));
				return false;
			}
		}
	}

	return true;
}

// Runs Func and retries it up to m_ConnectionSettings.RetryCount additional times.
// An attempt is retried when ShouldRetry() says so, or when the transport and status succeeded but
// Validate rejects the response. Any other failure is returned immediately. Between attempts it
// sleeps 100 * attempt-number (presumably milliseconds — Sleep's unit is not visible here).
// Returns early with the current result as soon as the abort callback reports true.
// Note: the result of the final attempt is returned without being passed to Validate.
cpr::Response
CprHttpClient::DoWithRetry(std::string_view SessionId, std::function&& Func, std::function&& Validate)
{
	uint8_t Attempt = 0;
	cpr::Response Result = Func();
	while (Attempt < m_ConnectionSettings.RetryCount)
	{
		if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
		{
			return Result;
		}
		if (!ShouldRetry(Result))
		{
			if (Result.error || !IsHttpSuccessCode(Result.status_code))
			{
				break;
			}
			if (Validate(Result))
			{
				break;
			}
		}
		Sleep(100 * (Attempt + 1));
		Attempt++;
		const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
		if (!Quiet)
		{
			// Result is consumed here for logging; it is overwritten by the next attempt below.
			ZEN_INFO("{} Attempt {}/{}",
					 CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
					 Attempt,
					 m_ConnectionSettings.RetryCount + 1);
		}
		Result = Func();
	}
	return Result;
}

// Same retry policy as the overload above, but validation is ValidatePayload() against the response
// text or the temp file PayloadFile (which Func may populate, e.g. in Download).
cpr::Response
CprHttpClient::DoWithRetry(std::string_view SessionId, std::function&& Func, std::unique_ptr& PayloadFile)
{
	uint8_t Attempt = 0;
	cpr::Response Result = Func();
	while (Attempt < m_ConnectionSettings.RetryCount)
	{
		if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
		{
			return Result;
		}
		if (!ShouldRetry(Result))
		{
			if (Result.error || !IsHttpSuccessCode(Result.status_code))
			{
				break;
			}
			if (ValidatePayload(Result, PayloadFile))
			{
				break;
			}
		}
		Sleep(100 * (Attempt + 1));
		Attempt++;
		const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
		if (!Quiet)
		{
			ZEN_INFO("{} Attempt {}/{}",
					 CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
					 Attempt,
					 m_ConnectionSettings.RetryCount + 1);
		}
		Result = Func();
	}
	return Result;
}

//////////////////////////////////////////////////////////////////////////

// Checks out a cpr session from the pool (creating one if the pool is empty) and prepares it for a
// request to BaseUrl + ResourcePath.
//  - Connect timeout, timeout, HTTP/2 prior knowledge and verbose debug logging are applied only
//    when a session is newly created; a pooled session keeps the settings it was created with.
//  - In verbose mode the bearer token in outgoing Authorization headers is replaced by
//    "[N char token]" before logging.
//  - AdditionalHeader, the "UE-Session" header, the Authorization header and query Parameters are
//    then applied; parameters are always reset so nothing leaks from a previous use.
// The returned Session wraps the cpr session together with this client.
CprHttpClient::Session
CprHttpClient::AllocSession(const std::string_view BaseUrl,
							const std::string_view ResourcePath,
							const HttpClientSettings& ConnectionSettings,
							const KeyValueMap& AdditionalHeader,
							const KeyValueMap& Parameters,
							const std::string_view SessionId,
							std::optional AccessToken)
{
	ZEN_TRACE_CPU("CprHttpClient::AllocSession");
	cpr::Session* CprSession = nullptr;
	m_SessionLock.WithExclusiveLock([&] {
		if (!m_Sessions.empty())
		{
			CprSession = m_Sessions.back();
			m_Sessions.pop_back();
		}
	});

	if (CprSession == nullptr)
	{
		CprSession = new cpr::Session();
		CprSession->SetConnectTimeout(ConnectionSettings.ConnectTimeout);
		CprSession->SetTimeout(ConnectionSettings.Timeout);
		if (ConnectionSettings.AssumeHttp2)
		{
			CprSession->SetHttpVersion(cpr::HttpVersion{cpr::HttpVersionCode::VERSION_2_0_PRIOR_KNOWLEDGE});
		}
		if (ConnectionSettings.Verbose)
		{
			// CprSession->SetVerbose(cpr::Verbose{ true });
			CprSession->SetDebugCallback(cpr::DebugCallback{
				[this](cpr::DebugCallback::InfoType type, std::string data, intptr_t userdata) {
					cpr::Session* CprSession = (cpr::Session*)userdata;
					ZEN_UNUSED(CprSession);
					switch (type)
					{
						case cpr::DebugCallback::InfoType::TEXT:
							if (data.find("need more data"sv) == std::string::npos)
							{
								ZEN_INFO("TEXT: {}", data);
							}
							break;
						case cpr::DebugCallback::InfoType::HEADER_IN:
							ZEN_INFO("HIN : {}", data);
							break;
						case cpr::DebugCallback::InfoType::HEADER_OUT:
							if (std::string::size_type TokenPos = data.find("Authorization: Bearer "sv); TokenPos != std::string::npos)
							{
								// Skip past "Authorization: Bearer " (22 characters) and redact up to end of line.
								TokenPos += 22;
								std::string::size_type TokenEndPos = data.find_first_of("\r\n", TokenPos);
								if (TokenEndPos == std::string::npos)
								{
									TokenEndPos = data.length();
								}
								std::string Copy = data;
								Copy.replace(Copy.begin() + TokenPos,
											 Copy.begin() + TokenEndPos,
											 fmt::format("[{} char token]", TokenEndPos - TokenPos));
								ZEN_INFO("HOUT: {}", Copy);
							}
							else
							{
								ZEN_INFO("HOUT: {}", data);
							}
							break;
						case cpr::DebugCallback::InfoType::DATA_IN:
							// ZEN_INFO("DATA_IN: {}", data);
							break;
						case cpr::DebugCallback::InfoType::DATA_OUT:
							// ZEN_INFO("DATA_OUT: {}", data);
							break;
						case cpr::DebugCallback::InfoType::SSL_DATA_IN:
							// ZEN_INFO("SSL_DATA_IN: {}", data);
							break;
						case cpr::DebugCallback::InfoType::SSL_DATA_OUT:
							// ZEN_INFO("SSL_DATA_OUT: {}", data);
							break;
					}
				},
				(intptr_t)CprSession});
		}
	}

	if (!AdditionalHeader->empty())
	{
		CprSession->SetHeader(cpr::Header(AdditionalHeader->begin(), AdditionalHeader->end()));
	}
	if (!SessionId.empty())
	{
		CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}});
	}
	if (AccessToken)
	{
		CprSession->UpdateHeader({{"Authorization", AccessToken->Value}});
	}
	if (!Parameters->empty())
	{
		cpr::Parameters Tmp;
		for (auto It = Parameters->begin(); It != Parameters->end(); It++)
		{
			Tmp.Add({It->first, It->second});
		}
		CprSession->SetParameters(Tmp);
	}
	else
	{
		CprSession->SetParameters({});
	}

	ExtendableStringBuilder<128> UrlBuffer;
	UrlBuffer << BaseUrl << ResourcePath;
	CprSession->SetUrl(UrlBuffer.c_str());

	return Session(this, CprSession);
}

// Clears the URL, headers and body of a session and returns it to the pool for reuse.
// Query parameters are not cleared here; AllocSession always overwrites them.
void
CprHttpClient::ReleaseSession(cpr::Session* CprSession)
{
	ZEN_TRACE_CPU("CprHttpClient::ReleaseSession");
	CprSession->SetUrl({});
	CprSession->SetHeader({});
	CprSession->SetBody({});
	m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(CprSession); });
}

// Sends a CbPackage in two POSTs on one session, both tagged with the same "UE-Request" id:
//  1. If the package has attachments, POST a kCbPackageOffer object listing their hashes ("offer").
//     On a 200 reply that validates as compact binary, the hashes in its "need" array are the
//     attachments to send. On any other reply no attachments are sent.
//  2. POST the package (object + needed attachments) as kCbPackage.
// Unlike the verbs below, this is not wrapped in DoWithRetry. On a non-success status only the
// status code is returned.
CprHttpClient::Response
CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader)
{
	ZEN_TRACE_CPU("CprHttpClient::TransactPackage");

	Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());

	// First, list of offered chunks for filtering on the server end

	std::vector AttachmentsToSend;
	std::span Attachments = Package.GetAttachments();

	const uint32_t RequestId = ++HttpClientRequestIdCounter;
	auto RequestIdString = fmt::to_string(RequestId);

	if (Attachments.empty() == false)
	{
		CbObjectWriter Writer;
		Writer.BeginArray("offer");

		for (const CbAttachment& Attachment : Attachments)
		{
			Writer.AddHash(Attachment.GetHash());
		}

		Writer.EndArray();

		BinaryWriter MemWriter;
		Writer.Save(MemWriter);

		Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackageOffer), {"UE-Request", RequestIdString}});
		Sess->SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()});

		cpr::Response FilterResponse = Sess.Post();

		if (FilterResponse.status_code == 200)
		{
			IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size());
			CbValidateError ValidationError = CbValidateError::None;
			if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError);
				ValidationError == CbValidateError::None)
			{
				for (CbFieldView& Entry : ResponseObject["need"])
				{
					// NOTE(review): this asserts on server-provided data; a malformed "need" entry
					// trips the assertion rather than being rejected gracefully.
					ZEN_ASSERT(Entry.IsHash());
					AttachmentsToSend.push_back(Entry.AsHash());
				}
			}
		}
	}

	// Prepare package for send

	CbPackage SendPackage;
	SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash());

	for (const IoHash& AttachmentCid : AttachmentsToSend)
	{
		const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid);

		if (Attachment)
		{
			SendPackage.AddAttachment(*Attachment);
		}
		else
		{
			// This should be an error -- server asked to have something we can't find
		}
	}

	// Transmit package payload

	CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage);
	SharedBuffer FlatMessage = Message.Flatten();

	Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackage), {"UE-Request", RequestIdString}});
	Sess->SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()});

	cpr::Response FilterResponse = Sess.Post();

	if (!IsHttpSuccessCode(FilterResponse.status_code))
	{
		return {.StatusCode = HttpResponseCode(FilterResponse.status_code)};
	}

	IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size());

	if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end())
	{
		HttpContentType ContentType = ParseContentType(It->second);

		ResponseBuffer.SetContentType(ContentType);
	}

	return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer};
}

//////////////////////////////////////////////////////////////////////////
//
// Standard HTTP verbs
//

// PUT Payload (in memory) with a Content-Type taken from the payload, retried via DoWithRetry.
CprHttpClient::Response
CprHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
{
	ZEN_TRACE_CPU("CprHttpClient::Put");

	return CommonResponse(
		m_SessionId,
		DoWithRetry(m_SessionId,
					[&]() {
						Session Sess =
							AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
						Sess->SetBody(AsCprBody(Payload));
						Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())});
						return Sess.Put();
					}),
		{});
}

// Body-less PUT with query Parameters and an explicit "Content-Length: 0" header.
CprHttpClient::Response
CprHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters)
{
	ZEN_TRACE_CPU("CprHttpClient::Put");

	return CommonResponse(m_SessionId,
						  DoWithRetry(m_SessionId,
									  [&]() {
										  Session Sess = AllocSession(m_BaseUri,
																	  Url,
																	  m_ConnectionSettings,
																	  {{"Content-Length", "0"}},
																	  Parameters,
																	  m_SessionId,
																	  GetAccessToken());

										  return Sess.Put();
									  }),
						  {});
}

// GET with retries; successful responses are additionally checked by ValidatePayload (without a
// temp file), and a failed check triggers a retry.
CprHttpClient::Response
CprHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
{
	ZEN_TRACE_CPU("CprHttpClient::Get");
	return CommonResponse(
		m_SessionId,
		DoWithRetry(
			m_SessionId,
			[&]() {
				Session Sess =
					AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken());
				return Sess.Get();
			},
			[this](cpr::Response& Result) {
				std::unique_ptr NoTempFile;
				return ValidatePayload(Result, NoTempFile);
			}),
		{});
}

// HEAD with retries.
CprHttpClient::Response
CprHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader)
{
	ZEN_TRACE_CPU("CprHttpClient::Head");

	return CommonResponse(
		m_SessionId,
		DoWithRetry(m_SessionId,
					[&]() {
						Session Sess =
							AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
						return Sess.Head();
					}),
		{});
}

// DELETE with retries.
CprHttpClient::Response
CprHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader)
{
	ZEN_TRACE_CPU("CprHttpClient::Delete");

	return CommonResponse(
		m_SessionId,
		DoWithRetry(m_SessionId,
					[&]() {
						Session Sess =
							AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
						return Sess.Delete();
					}),
		{});
}
CprHttpClient::Response CprHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) { ZEN_TRACE_CPU("CprHttpClient::PostNoPayload"); return CommonResponse( m_SessionId, DoWithRetry(m_SessionId, [&]() { Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); return Sess.Post(); }), {}); } CprHttpClient::Response CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) { return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader); } CprHttpClient::Response CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) { ZEN_TRACE_CPU("CprHttpClient::PostWithPayload"); return CommonResponse( m_SessionId, DoWithRetry( m_SessionId, [&]() { Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); Sess->UpdateHeader({HeaderContentType(ContentType)}); IoBufferFileReference FileRef = {nullptr, 0, 0}; if (Payload.GetFileReference(FileRef)) { uint64_t Offset = 0; detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { size = Min(size, Payload.GetSize() - Offset); Buffer.Read(buffer, size); Offset += size; return true; }; return Sess.Post(cpr::ReadCallback(gsl::narrow(Payload.GetSize()), ReadCallback)); } Sess->SetBody(AsCprBody(Payload)); return Sess.Post(); }), {}); } CprHttpClient::Response CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) { ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload"); return CommonResponse( m_SessionId, DoWithRetry(m_SessionId, [&]() { Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, 
GetAccessToken()); Sess->SetBody(AsCprBody(Payload)); Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); return Sess.Post(); }), {}); } CprHttpClient::Response CprHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) { return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader); } CprHttpClient::Response CprHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) { ZEN_TRACE_CPU("CprHttpClient::Post"); return CommonResponse( m_SessionId, DoWithRetry(m_SessionId, [&]() { Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); Sess->UpdateHeader({HeaderContentType(ContentType)}); detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); auto ReadCallback = [this, &Reader](char* buffer, size_t& size, intptr_t) { if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) { return false; } size = Reader.Read(buffer, size); return true; }; return Sess.Post(cpr::ReadCallback(gsl::narrow(Payload.GetSize()), ReadCallback)); }), {}); } CprHttpClient::Response CprHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) { ZEN_TRACE_CPU("CprHttpClient::Upload"); return CommonResponse( m_SessionId, DoWithRetry( m_SessionId, [&]() { Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())}); IoBufferFileReference FileRef = {nullptr, 0, 0}; if (Payload.GetFileReference(FileRef)) { uint64_t Offset = 0; detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); auto ReadCallback = [this, &Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) { 
return false; } size = Min(size, Payload.GetSize() - Offset); Buffer.Read(buffer, size); Offset += size; return true; }; return Sess.Put(cpr::ReadCallback(gsl::narrow(Payload.GetSize()), ReadCallback)); } Sess->SetBody(AsCprBody(Payload)); return Sess.Put(); }), {}); } CprHttpClient::Response CprHttpClient::Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) { ZEN_TRACE_CPU("CprHttpClient::Upload"); return CommonResponse( m_SessionId, DoWithRetry(m_SessionId, [&]() { Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); Sess->UpdateHeader({HeaderContentType(ContentType)}); detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); auto ReadCallback = [this, &Reader](char* buffer, size_t& size, intptr_t) { if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) { return false; } size = Reader.Read(buffer, size); return true; }; return Sess.Put(cpr::ReadCallback(gsl::narrow(Payload.GetSize()), ReadCallback)); }), {}); } CprHttpClient::Response CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) { ZEN_TRACE_CPU("CprHttpClient::Download"); std::string PayloadString; std::unique_ptr PayloadFile; HttpContentType ContentType = HttpContentType::kUnknownContentType; detail::MultipartBoundaryParser BoundaryParser; bool IsMultiRangeResponse = false; cpr::Response Response = DoWithRetry( m_SessionId, [&]() { auto DownloadCallback = [&](std::string data, intptr_t) { if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) { return false; } if (IsMultiRangeResponse) { BoundaryParser.ParseInput(data); } if (PayloadFile) { ZEN_ASSERT(PayloadString.empty()); std::error_code Ec = PayloadFile->Write(data); if (Ec) { ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. 
Reason: {}", TempFolderPath.string(), Ec.message()); return false; } } else { PayloadString.append(data); } return true; }; uint64_t RequestedContentLength = (uint64_t)-1; if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) { if (RangeIt->second.starts_with("bytes")) { std::string_view RangeValue(RangeIt->second); size_t RangeStartPos = RangeValue.find('=', 5); if (RangeStartPos != std::string::npos) { RangeStartPos++; while (RangeValue[RangeStartPos] == ' ') { RangeStartPos++; } RequestedContentLength = 0; while (RangeStartPos < RangeValue.length()) { size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); if (RangeEnd == std::string::npos) { RangeEnd = RangeValue.length(); } std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); size_t RangeSplitPos = RangeString.find('-'); if (RangeSplitPos != std::string::npos) { std::optional RequestedRangeStart = ParseInt(RangeString.substr(0, RangeSplitPos)); std::optional RequestedRangeEnd = ParseInt(RangeString.substr(RangeSplitPos + 1)); if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) { RequestedContentLength += RequestedRangeEnd.value() - 1; } } RangeStartPos = RangeEnd; while (RangeStartPos != RangeValue.length() && (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' ')) { RangeStartPos++; } } } } } cpr::Response Response; { std::vector> ReceivedHeaders; auto HeaderCallback = [&](std::string header, intptr_t) { const std::pair Header = detail::GetHeaderKeyAndValue(header); if (Header.first == "Content-Length"sv) { std::optional ContentLength = ParseInt(Header.second); if (ContentLength.has_value()) { if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) { PayloadFile = std::make_unique(); std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); if (Ec) { ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. 
Reason: {}", TempFolderPath.string(), Ec.message()); PayloadFile.reset(); } } else { PayloadString.reserve(ContentLength.value()); } } } else if (Header.first == "Content-Type") { IsMultiRangeResponse = BoundaryParser.Init(Header.second); if (!IsMultiRangeResponse) { ContentType = ParseContentType(Header.second); } } else if (Header.first == "Content-Range") { if (!IsMultiRangeResponse) { std::pair Range = detail::ParseContentRange(Header.second); if (Range.second != 0) { BoundaryParser.Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0, .RangeOffset = Range.first, .RangeLength = Range.second, .ContentType = ContentType}); } } } if (!Header.first.empty()) { ReceivedHeaders.emplace_back(std::move(Header)); } return 1; }; Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); for (const std::pair& H : ReceivedHeaders) { Response.header.insert_or_assign(H.first, H.second); } } if (m_ConnectionSettings.AllowResume) { auto SupportsRanges = [](const cpr::Response& Response) -> bool { if (Response.header.find("Content-Range") != Response.header.end()) { return true; } if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) { return It->second == "bytes"sv; } return false; }; auto ShouldResume = [&SupportsRanges, &IsMultiRangeResponse](const cpr::Response& Response) -> bool { if (IsMultiRangeResponse) { return false; } if (ShouldRetry(Response)) { return SupportsRanges(Response); } return false; }; if (ShouldResume(Response)) { auto It = Response.header.find("Content-Length"); if (It != Response.header.end()) { uint64_t ContentLength = RequestedContentLength; if (ContentLength == uint64_t(-1)) { if (auto ParsedContentLength = ParseInt(It->second); ParsedContentLength.has_value()) { ContentLength = ParsedContentLength.value(); } } std::vector> 
ReceivedHeaders; auto HeaderCallback = [&](std::string header, intptr_t) { const std::pair Header = detail::GetHeaderKeyAndValue(header); if (!Header.first.empty()) { ReceivedHeaders.emplace_back(std::move(Header)); } if (Header.first == "Content-Range"sv) { if (Header.second.starts_with("bytes "sv)) { size_t RangeStartEnd = Header.second.find('-', 6); if (RangeStartEnd != std::string::npos) { const auto Start = ParseInt(Header.second.substr(6, RangeStartEnd - 6)); if (Start) { uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); if (Start.value() == DownloadedSize) { return 1; } else if (Start.value() > DownloadedSize) { return 0; } if (PayloadFile) { PayloadFile->ResetWritePos(Start.value()); } else { PayloadString = PayloadString.substr(0, Start.value()); } return 1; } } } return 0; } return 1; }; KeyValueMap HeadersWithRange(AdditionalHeader); do { uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) { if (RangeIt->second == Range) { // If we didn't make any progress, abort break; } } HeadersWithRange.Entries.insert_or_assign("Range", Range); Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); for (const std::pair& H : ReceivedHeaders) { Response.header.insert_or_assign(H.first, H.second); } ReceivedHeaders.clear(); } while (ShouldResume(Response)); } } } if (!PayloadString.empty()) { Response.text = std::move(PayloadString); } return Response; }, PayloadFile); return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? 
PayloadFile->DetachToIoBuffer() : IoBuffer{}, std::move(BoundaryParser.Boundaries)); } } // namespace zen